diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..26809e1d8908f323344e17469a246fbcc2decb63 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/plywood-4k.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.hdf5 filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.obj filter=lfs diff=lfs merge=lfs -text
+*.mtl filter=lfs diff=lfs merge=lfs -text
+*.stl filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.dae filter=lfs diff=lfs merge=lfs -text
+*.hdr filter=lfs diff=lfs merge=lfs -text
+*.msh filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
index c779f46a9c5c1cbe85b2ca5999e182e9a3cde2f1..9091d56ef695d515a9b7e56c5045716ff02783ee 100644
--- a/app.py
+++ b/app.py
@@ -1,109 +1,427 @@
"""
-Phantom Video Processor - Hugging Face Space
+Phantom Video Processor - Hugging Face Space Demo
+将人类手部视频转换为机器人演示数据
"""
import gradio as gr
import spaces
import subprocess
import sys
+import os
+import shutil
+import tempfile
from pathlib import Path
-# ========== 环境配置 ==========
-
+# ========== 路径配置 ==========
PHANTOM_DIR = Path("/home/user/app/phantom")
+DATA_RAW_DIR = PHANTOM_DIR / "data" / "raw"
+DATA_PROCESSED_DIR = PHANTOM_DIR / "data" / "processed"
+MANO_DIR = PHANTOM_DIR / "submodules" / "phantom-hamer" / "_DATA" / "data" / "mano"
+
+# 添加 Phantom 到 Python 路径
+if PHANTOM_DIR.exists():
+ sys.path.insert(0, str(PHANTOM_DIR))
+ sys.path.insert(0, str(PHANTOM_DIR / "phantom"))
+
+# ========== 环境检测 ==========
+def check_environment():
+ """检查环境状态"""
+ status = {
+ "phantom_installed": Path("/tmp/.phantom_ready").exists(),
+ "mano_ready": (MANO_DIR / "MANO_LEFT.pkl").exists() and (MANO_DIR / "MANO_RIGHT.pkl").exists(),
+ "sample_data": (DATA_RAW_DIR / "pick_and_place").exists(),
+ "cuda_available": False,
+ "gpu_name": None
+ }
+
+ try:
+ import torch
+ status["cuda_available"] = torch.cuda.is_available()
+ if status["cuda_available"]:
+ status["gpu_name"] = torch.cuda.get_device_name(0)
+ except:
+ pass
+
+ return status
+
+def get_status_text():
+ """获取状态文本"""
+ status = check_environment()
+ lines = []
+ lines.append("=" * 40)
+ lines.append("环境状态")
+ lines.append("=" * 40)
+ lines.append(f"Phantom 安装: {'✅' if status['phantom_installed'] else '❌ 首次运行需初始化'}")
+ lines.append(f"MANO 模型: {'✅' if status['mano_ready'] else '❌ 请上传 MANO 模型文件'}")
+ lines.append(f"示例数据: {'✅' if status['sample_data'] else '⏳ 将自动下载'}")
+ lines.append(f"CUDA: {'✅ ' + (status['gpu_name'] or '') if status['cuda_available'] else '⏳ GPU 将在处理时分配'}")
+ lines.append("=" * 40)
+ return "\n".join(lines)
+
+# ========== MANO 模型上传 ==========
+def upload_mano_files(left_file, right_file):
+ """上传 MANO 模型文件"""
+ MANO_DIR.mkdir(parents=True, exist_ok=True)
+
+ messages = []
-def setup_environment():
- """配置Phantom环境(仅首次运行)"""
-
- # 检查是否已配置
+ if left_file is not None:
+ dest = MANO_DIR / "MANO_LEFT.pkl"
+ shutil.copy(left_file.name, dest)
+ messages.append(f"✅ MANO_LEFT.pkl 已保存")
+
+ if right_file is not None:
+ dest = MANO_DIR / "MANO_RIGHT.pkl"
+ shutil.copy(right_file.name, dest)
+ messages.append(f"✅ MANO_RIGHT.pkl 已保存")
+
+ if not messages:
+ return "⚠️ 请选择文件上传"
+
+ return "\n".join(messages) + "\n\n" + get_status_text()
+
+# ========== 初始化环境 ==========
+def initialize_environment(progress=gr.Progress()):
+ """初始化 Phantom 环境"""
if Path("/tmp/.phantom_ready").exists():
- print("✅ Phantom环境已配置")
- return True
-
- print("🔧 首次运行,配置环境(约5-10分钟)...")
-
- # 运行setup.sh
+ return "✅ 环境已就绪\n\n" + get_status_text()
+
+ progress(0, desc="开始初始化...")
+
setup_script = Path("/home/user/app/setup.sh")
- if setup_script.exists():
- try:
- result = subprocess.run(
- ["bash", str(setup_script)],
- check=True,
- capture_output=True,
- text=True
- )
- print(result.stdout)
- print("✅ 环境配置完成")
- return True
- except subprocess.CalledProcessError as e:
- print(f"❌ 配置失败: {e.stderr}")
- return False
- else:
- print("⚠️ setup.sh不存在")
- return False
+ if not setup_script.exists():
+ return "❌ setup.sh 不存在"
-# 添加Phantom到Python路径
-if PHANTOM_DIR.exists():
- sys.path.insert(0, str(PHANTOM_DIR))
+ try:
+ # 运行 setup.sh
+ progress(0.1, desc="运行安装脚本...")
+ process = subprocess.Popen(
+ ["bash", str(setup_script)],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1
+ )
+
+ output_lines = []
+ for line in iter(process.stdout.readline, ''):
+ output_lines.append(line.strip())
+ if len(output_lines) > 50:
+ output_lines = output_lines[-50:] # 保留最后 50 行
+
+ process.wait()
-# 启动时配置环境
-phantom_ready = setup_environment()
+ if process.returncode == 0:
+ progress(1.0, desc="完成!")
+ return "✅ 初始化完成!\n\n" + "\n".join(output_lines[-20:]) + "\n\n" + get_status_text()
+ else:
+ return f"❌ 初始化失败 (返回码: {process.returncode})\n\n" + "\n".join(output_lines[-30:])
-# ========== 其余代码保持不变 ==========
+ except Exception as e:
+ return f"❌ 初始化错误: {str(e)}"
-@spaces.GPU(duration=120)
-def process_video(video_file, robot_type, target_hand):
- """处理视频"""
+# ========== 视频处理 ==========
+@spaces.GPU(duration=300)
+def process_video(
+ video_file,
+ robot_type,
+ target_hand,
+ processing_mode,
+ use_sample_data,
+ progress=gr.Progress()
+):
+ """
+ 处理视频 - 将人类手部转换为机器人
+ """
import torch
- if video_file is None:
- return None, None, "请先上传视频"
+ # 状态信息
+ status_lines = []
- # 检查GPU
+ # GPU 检查
if torch.cuda.is_available():
gpu = torch.cuda.get_device_name(0)
- status = f"✅ GPU: {gpu}\n"
+ status_lines.append(f"✅ GPU: {gpu}")
+ status_lines.append(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
- status = "⚠️ 未检测到GPU\n"
-
- status += f"视频: {video_file}\n"
- status += f"机器人: {robot_type}\n"
- status += f"手部: {target_hand}\n"
-
- if not phantom_ready:
- status += "\n⚠️ Phantom环境未就绪"
-
- return None, None, status
-
-# Gradio界面
-with gr.Blocks(title="Phantom") as demo:
- gr.Markdown("# 🤖 Phantom - 机器人视频生成器")
-
- with gr.Row():
- with gr.Column():
- video_input = gr.Video(label="上传视频")
- robot_type = gr.Dropdown(
- choices=["Panda", "Kinova3", "UR5e"],
- value="Panda",
- label="机器人类型"
- )
- target_hand = gr.Radio(
- choices=["left", "right"],
- value="left",
- label="目标手部"
+ status_lines.append("❌ GPU 不可用")
+ return None, None, "\n".join(status_lines)
+
+ # 检查环境
+ if not Path("/tmp/.phantom_ready").exists():
+ status_lines.append("❌ 请先点击「初始化环境」按钮")
+ return None, None, "\n".join(status_lines)
+
+ # 检查 MANO
+ if not (MANO_DIR / "MANO_LEFT.pkl").exists():
+ status_lines.append("❌ 请先上传 MANO 模型文件")
+ return None, None, "\n".join(status_lines)
+
+ progress(0.1, desc="准备处理...")
+
+ # 确定输入数据
+ if use_sample_data:
+ demo_name = "pick_and_place"
+ data_root = str(DATA_RAW_DIR)
+ status_lines.append(f"📂 使用示例数据: {demo_name}")
+ else:
+ if video_file is None:
+ status_lines.append("❌ 请上传视频或选择使用示例数据")
+ return None, None, "\n".join(status_lines)
+
+ # 创建临时目录存放上传的视频
+ demo_name = "user_upload"
+ user_data_dir = DATA_RAW_DIR / demo_name / "0"
+ user_data_dir.mkdir(parents=True, exist_ok=True)
+
+ # 复制视频到正确位置
+ video_dest = user_data_dir / "video.mkv"
+ shutil.copy(video_file, video_dest)
+ data_root = str(DATA_RAW_DIR)
+ status_lines.append(f"📂 处理上传视频: {video_file}")
+
+ status_lines.append(f"🤖 机器人类型: {robot_type}")
+ status_lines.append(f"✋ 目标手部: {target_hand}")
+ status_lines.append(f"⚙️ 处理模式: {processing_mode}")
+ status_lines.append("-" * 40)
+
+ progress(0.2, desc="开始处理...")
+
+ # 构建处理命令
+ cmd = [
+ sys.executable,
+ str(PHANTOM_DIR / "phantom" / "process_data.py"),
+ f"demo_name={demo_name}",
+ f"data_root_dir={data_root}",
+ f"processed_data_root_dir={str(DATA_PROCESSED_DIR)}",
+ f"mode={processing_mode}",
+ f"robot={robot_type}",
+ f"target_hand={target_hand}",
+ "bimanual_setup=single_arm",
+ "demo_num=0", # 只处理第一个 demo
+ ]
+
+ status_lines.append(f"命令: {' '.join(cmd)}")
+
+ try:
+ # 运行处理
+ progress(0.3, desc="处理中...")
+
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ cwd=str(PHANTOM_DIR / "phantom"),
+ env={**os.environ, "PYTHONPATH": str(PHANTOM_DIR)}
+ )
+
+ output_lines = []
+ for line in iter(process.stdout.readline, ''):
+ line = line.strip()
+ if line:
+ output_lines.append(line)
+ # 更新进度
+ if "BBOX" in line:
+ progress(0.4, desc="检测边界框...")
+ elif "HAND2D" in line:
+ progress(0.5, desc="提取2D手部姿态...")
+ elif "SEGMENTATION" in line:
+ progress(0.6, desc="分割手臂...")
+ elif "ACTION" in line:
+ progress(0.7, desc="提取动作...")
+ elif "INPAINT" in line:
+ progress(0.8, desc="视频修复...")
+ elif "ROBOT" in line:
+ progress(0.9, desc="叠加机器人...")
+
+ process.wait()
+
+ progress(1.0, desc="完成!")
+
+ # 添加处理输出
+ status_lines.append("-" * 40)
+ status_lines.append("处理日志 (最后 20 行):")
+ status_lines.extend(output_lines[-20:])
+
+ # 查找输出文件
+ output_video = None
+ output_data = None
+
+ processed_dir = DATA_PROCESSED_DIR / demo_name / "0"
+
+ # 查找生成的视频
+ video_pattern = f"video_overlay_{robot_type}_single_arm.mkv"
+ for f in processed_dir.glob("**/*.mkv"):
+ if robot_type.lower() in f.name.lower():
+ output_video = str(f)
+ break
+
+ # 查找训练数据
+ for f in processed_dir.glob("**/training_data*.npz"):
+ output_data = str(f)
+ break
+
+ if output_video:
+ status_lines.append(f"\n✅ 输出视频: {output_video}")
+ if output_data:
+ status_lines.append(f"✅ 训练数据: {output_data}")
+
+ if process.returncode == 0:
+ status_lines.insert(0, "✅ 处理完成!")
+ else:
+ status_lines.insert(0, f"⚠️ 处理完成但有警告 (返回码: {process.returncode})")
+
+ return output_video, output_data, "\n".join(status_lines)
+
+ except Exception as e:
+ import traceback
+ status_lines.append(f"\n❌ 处理错误: {str(e)}")
+ status_lines.append(traceback.format_exc())
+ return None, None, "\n".join(status_lines)
+
+# ========== Gradio 界面 ==========
+with gr.Blocks(
+ title="Phantom - 机器人视频生成器",
+ theme=gr.themes.Soft()
+) as demo:
+
+ gr.Markdown("""
+ # 🤖 Phantom - 将人类视频转换为机器人演示
+
+ **论文**: [Phantom: Training Robots Without Robots Using Only Human Videos](https://phantom-human-videos.github.io/)
+
+ 将人类手部操作视频自动转换为机器人演示数据,用于训练机器人策略。
+ """)
+
+ with gr.Tabs():
+ # ========== 环境设置 Tab ==========
+ with gr.TabItem("1️⃣ 环境设置"):
+ gr.Markdown("""
+ ### 首次使用需要完成以下步骤:
+
+ 1. **初始化环境** - 安装依赖和下载模型 (首次约 5-10 分钟)
+ 2. **上传 MANO 模型** - 需要从官网注册下载
+ """)
+
+ with gr.Row():
+ with gr.Column():
+ init_btn = gr.Button("🔧 初始化环境", variant="primary", size="lg")
+ init_output = gr.Textbox(
+ label="初始化状态",
+ lines=15,
+ value=get_status_text()
+ )
+
+ with gr.Column():
+ gr.Markdown("""
+ ### MANO 模型下载
+
+ 1. 访问 [MANO 官网](https://mano.is.tue.mpg.de/)
+ 2. 注册账号并下载模型
+ 3. 上传 `MANO_LEFT.pkl` 和 `MANO_RIGHT.pkl`
+ """)
+
+ mano_left = gr.File(label="MANO_LEFT.pkl", file_types=[".pkl"])
+ mano_right = gr.File(label="MANO_RIGHT.pkl", file_types=[".pkl"])
+ upload_btn = gr.Button("📤 上传 MANO 模型")
+ upload_output = gr.Textbox(label="上传状态", lines=5)
+
+ init_btn.click(fn=initialize_environment, outputs=init_output)
+ upload_btn.click(fn=upload_mano_files, inputs=[mano_left, mano_right], outputs=upload_output)
+
+ # ========== 视频处理 Tab ==========
+ with gr.TabItem("2️⃣ 视频处理"):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("### 输入设置")
+
+ use_sample = gr.Checkbox(
+ label="使用示例数据 (pick_and_place)",
+ value=True,
+ info="推荐首次使用时勾选,使用预置的示例视频"
+ )
+
+ video_input = gr.Video(
+ label="或上传自己的视频",
+ interactive=True
+ )
+
+ robot_type = gr.Dropdown(
+ choices=["Panda", "Kinova3", "UR5e", "IIWA", "Jaco"],
+ value="Panda",
+ label="机器人类型"
+ )
+
+ target_hand = gr.Radio(
+ choices=["left", "right"],
+ value="left",
+ label="目标手部"
+ )
+
+ processing_mode = gr.Dropdown(
+ choices=[
+ "bbox",
+ "hand2d",
+ "arm_segmentation",
+ "hand_inpaint",
+ "robot_inpaint",
+ "all"
+ ],
+ value="bbox",
+ label="处理模式",
+ info="建议逐步运行: bbox -> hand2d -> arm_segmentation -> hand_inpaint -> robot_inpaint"
+ )
+
+ process_btn = gr.Button("🚀 开始处理", variant="primary", size="lg")
+
+ with gr.Column():
+ gr.Markdown("### 输出结果")
+
+ video_output = gr.Video(label="生成的机器人视频")
+ data_output = gr.File(label="训练数据 (NPZ)")
+ status_output = gr.Textbox(label="处理状态", lines=20)
+
+ process_btn.click(
+ fn=process_video,
+ inputs=[video_input, robot_type, target_hand, processing_mode, use_sample],
+ outputs=[video_output, data_output, status_output]
)
- btn = gr.Button("开始处理", variant="primary")
- with gr.Column():
- video_out = gr.Video(label="结果视频")
- data_out = gr.File(label="训练数据")
- status_out = gr.Textbox(label="状态", lines=10)
+ # ========== 说明 Tab ==========
+ with gr.TabItem("📖 说明"):
+ gr.Markdown("""
+ ## 处理流程
+
+ Phantom 将人类手部视频转换为机器人演示数据,处理步骤:
+
+ | 步骤 | 模式 | 描述 |
+ |------|------|------|
+ | 1 | `bbox` | 检测手部边界框 |
+ | 2 | `hand2d` | 提取 2D 手部姿态 |
+ | 3 | `arm_segmentation` | 分割人类手臂 |
+ | 4 | `hand_inpaint` | 移除手臂并修复背景 |
+ | 5 | `robot_inpaint` | 叠加虚拟机器人 |
+
+ ## 输入要求
+
+ - **视频格式**: MKV, MP4 等常见格式
+ - **分辨率**: 推荐 1080p
+ - **内容**: 单手操作视频,手部需清晰可见
+
+ ## GPU Zero 限制
+
+ - 单次处理时间限制: 300 秒
+ - 建议逐步运行各处理模式
+ - 复杂视频可能需要多次处理
+
+ ## 参考资料
- btn.click(
- fn=process_video,
- inputs=[video_input, robot_type, target_hand],
- outputs=[video_out, data_out, status_out]
- )
+ - [Phantom 论文](https://arxiv.org/abs/2503.00779)
+ - [GitHub 仓库](https://github.com/MarionLepert/phantom)
+ - [MANO 手部模型](https://mano.is.tue.mpg.de/)
+ """)
+# 启动
if __name__ == "__main__":
demo.queue().launch()
diff --git a/phantom b/phantom
deleted file mode 160000
index a8bb81c1bbe6ade129a1f6f0906482f510354a5e..0000000000000000000000000000000000000000
--- a/phantom
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a8bb81c1bbe6ade129a1f6f0906482f510354a5e
diff --git a/phantom/.gitignore b/phantom/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9d4c4b47f97936f015d1e9223c36e186b205383d
--- /dev/null
+++ b/phantom/.gitignore
@@ -0,0 +1,11 @@
+*.egg-info
+**/_DATA/*
+data/raw/*
+!data/raw/.gitkeep
+data/processed/*
+!data/processed/.gitkeep
+**/__pycache__/*
+*.pyc
+*.pth
+outputs/*
+phantom/outputs/*
\ No newline at end of file
diff --git a/phantom/.gitmodules b/phantom/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..f965f10b5541f50eba9f54f32884677fd641b8ea
--- /dev/null
+++ b/phantom/.gitmodules
@@ -0,0 +1,15 @@
+[submodule "submodules/phantom-E2FGVI"]
+ path = submodules/phantom-E2FGVI
+ url = git@github.com:MarionLepert/phantom-E2FGVI.git
+[submodule "submodules/sam2"]
+ path = submodules/sam2
+ url = git@github.com:facebookresearch/sam2.git
+[submodule "submodules/phantom-robosuite"]
+ path = submodules/phantom-robosuite
+ url = git@github.com:MarionLepert/phantom-robosuite.git
+[submodule "submodules/phantom-robomimic"]
+ path = submodules/phantom-robomimic
+ url = git@github.com:MarionLepert/phantom-robomimic.git
+[submodule "submodules/phantom-hamer"]
+ path = submodules/phantom-hamer
+ url = git@github.com:MarionLepert/phantom-hamer.git
diff --git a/phantom/LICENSE b/phantom/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a7919af83c0ef65fa3d553a06db8f0f491fd7cba
--- /dev/null
+++ b/phantom/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Stanford Interactive Perception and Robot Learning Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/phantom/README.md b/phantom/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cad6c69ba6b73a8ce9c87ec429c0ba4d25f6150e
--- /dev/null
+++ b/phantom/README.md
@@ -0,0 +1,168 @@
+# Code for Phantom and Masquerade
+[](https://www.python.org)
+[](https://opensource.org/licenses/MIT)
+
+
+This repository contains the code used to process human videos in [Phantom: Training Robots Without Robots Using Only Human Videos](https://phantom-human-videos.github.io/) and [Masquerade: Learning from In-the-wild Human Videos using Data-Editing](https://masquerade-robot.github.io/).
+
+
+
+Both projects use data editing to convert human videos into “robotized” demonstrations. They share much of the same codebase, with some differences in the processing pipeline:
+
+**Phantom**
+* Input: RGBD videos with a single left hand visible in every frame.
+* Data editing: inpaint the single human arm, overlay a rendered robot arm in the same pose.
+* Action labels: extract full 3D end-effector pose (position, orientation, gripper)
+
+**Masquerade**
+* Input: RGB videos from [Epic Kitchens](https://epic-kitchens.github.io/2025); one or both hands may be visible, sometimes occluded.
+* Data editing: segment and inpaint both arms, overlay a bimanual robot whose effectors follow the estimated poses (with a 3-4cm error along the depth direction due to lack of depth data)
+* Action labels: use 2D projected waypoints as auxiliary supervision only (not full 3D actions)
+
+
+
+## Installation
+1. Clone this repo recursively
+
+```bash
+git clone --recursive git@github.com:MarionLepert/phantom.git
+```
+
+2. Run the following script from the root directory to install the required conda environment.
+```bash
+./install.sh
+```
+
+3. Download the MANO hand models. To do so, go to the [MANO website](https://mano.is.tue.mpg.de/) and register to be able to download the models. Download the left and right hand models and move MANO_LEFT.pkl and MANO_RIGHT.pkl inside the `$ROOT_DIR/submodules/phantom-hamer/_DATA/data/mano/` folder.
+
+## Getting Started
+Process **Phantom** sample data (manually collected in-lab videos)
+```bash
+conda activate phantom
+
+python process_data.py demo_name=pick_and_place data_root_dir=../data/raw processed_data_root_dir=../data/processed mode=all
+```
+
+Process **Masquerade** sample data ([Epic Kitchens](https://epic-kitchens.github.io/2025) video)
+```bash
+conda activate phantom
+
+python process_data.py demo_name=epic data_root_dir=../data/raw processed_data_root_dir=../data/processed mode=all --config-name=epic
+```
+
+
+## Codebase Overview
+
+### Process data
+Each video is processed using the following steps:
+
+1. **Extract human hand bounding boxes**: `bbox_processor.py`
+ * `mode=bbox`
+
+2. **Extract 2d human hand poses**: `hand_processor.py`
+ * `mode=hand2d`: extract the 2d hand pose
+
+3. **Extract human and arm segmentation masks**: `segmentation_processor.py`
+ * `mode=hand_segmentation`: used for depth alignment in hand pose refinement (only works for hand3d)
+ * `mode=arm_segmentation`: needed in all cases to inpaint the human
+
+2. **Extract 3d human hand poses**: `hand_processor.py`
+ * `mode=hand3d`: extract the 3d hand pose (note: requires depth, and was only tested on the left hand)
+
+4. **Retarget human actions to robot actions**: `action_processor.py`
+ * `mode=action`
+
+5. **Smooth human poses**: `smoothing_processor.py`
+ * `mode=smoothing`
+
+6. **Remove hand from videos using inpainting**: `handinpaint_processor.py`
+ * `mode=hand_inpaint`
+ * Inpainting method [E2FGVI](https://arxiv.org/pdf/2204.02663) is used.
+
+7. **Overlay virtual robot on video**: `robotinpaint_processor.py`
+ * `mode=robot_inpaint`: overlay a single robot (default) or bimanual (epic mode) robot on the image
+
+
+### Config reference (see configuration files in `configs/`)
+
+| Flag | Type | Required | Choices | Description |
+|------|------|----------|---------|-------------|
+| `--demo_name` | `str` | ✅ | - | Name of the demonstration/dataset to process |
+| `--mode` | `str` (multiple) | ✅ | `bbox`, `hand2d`, `hand3d`, `hand_segmentation`, `arm_segmentation`, `action`, `smoothing`, `hand_inpaint`, `robot_inpaint`, `all` | Processing modes to run (can specify multiple with e.g. `'mode=[bbox,hand2d]'`) |
+| `--robot_name` | `str` | ✅ | `Panda`, `Kinova3`, `UR5e`, `IIWA`, `Jaco` | Type of robot to use for overlays |
+| `--gripper_name` | `str` | ❌ | `Robotiq85` | Type of gripper to use |
+| `--data_root_dir` | `str` | ❌ | - | Root directory containing raw video data |
+| `--processed_data_root_dir` | `str` | ❌ | - | Root directory to save processed data |
+| `--epic` | `bool` | ❌ | - | Use Epic-Kitchens dataset processing mode |
+| `--bimanual_setup` | `str` | ❌ | `single_arm`, `shoulders` | Bimanual setup configuration to use (shoulders corresponds to the bimanual hardware configuration used in Masquerade) |
+| `--target_hand` | `str` | ❌ | `left`, `right`, `both` | Which hand(s) to target for processing |
+| `--camera_intrinsics` | `str` | ❌ | - | Path to camera intrinsics file |
+| `--camera_extrinsics` | `str` | ❌ | - | Path to camera extrinsics file |
+| `--input_resolution` | `int` | ❌ | - | Resolution of input videos |
+| `--output_resolution` | `int` | ❌ | - | Resolution of output videos |
+| `--depth_for_overlay` | `bool` | ❌ | - | Use depth information for overlays |
+| `--demo_num` | `str` | ❌ | - | Process a single demo number instead of all demos |
+| `--debug_cameras` | `str` (multiple) | ❌ | - | Additional camera names to include for debugging |
+| `--constrained_hand` | `bool` | ❌ | - | Use constrained hand processing |
+| `--render` | `bool` | ❌ | - | Render the robot overlay on the video |
+
+**Note** Please specify `--bimanual_setup single_arm` along with `--target_hand left` or `--target_hand right` if you are using single arm. For bimanual setups, use `--bimanual_setup shoulders`.
+
+### Camera details
+* **Phantom**: a Zed2 camera was used to capture the sample data at HD1080 resolution.
+* **Masquerade**: We used Epic-Kitchens videos and used the camera intrinsics provided in the dataset. To use videos captured with a different camera resolution, update the camera intrinsics and extrinsics files in `$ROOT_DIR/phantom/camera/`.
+
+### Train policy
+After processing the video data, the edited data can be used to train a policy. The following files should be used:
+
+* Observations
+ * Phantom Samples: extract RGB images from `data/processed/pick_and_place/*/video_overlay_Panda_single_arm.mkv`
+ * Epic (In-the-wild Data) Samples: extract RGB images from `data/processed/epic/*/video_overlay_Kinova3_shoulders.mkv`
+
+* Actions
+ * Phantom Samples: All data stored in `data/processed/pick_and_place/*/inpaint_processor/training_data_single_arm.npz`
+ * Epic (In-the-wild Data) Samples: All data stored in `data/processed/epic/*/inpaint_processor/training_data_shoulders.npz`
+
+
+In Phantom, [Diffusion Policy](https://github.com/real-stanford/diffusion_policy) was used for policy training.
+
+
+## Citation
+```bibtex
+@article{lepert2025phantomtrainingrobotsrobots,
+ title={Phantom: Training Robots Without Robots Using Only Human Videos},
+ author={Marion Lepert and Jiaying Fang and Jeannette Bohg},
+ year={2025},
+ eprint={2503.00779},
+ archivePrefix={arXiv},
+ primaryClass={cs.RO},
+ url={https://arxiv.org/abs/2503.00779},
+ }
+```
+
+```bibtex
+@misc{lepert2025masqueradelearninginthewildhuman,
+ title={Masquerade: Learning from In-the-wild Human Videos using Data-Editing},
+ author={Marion Lepert and Jiaying Fang and Jeannette Bohg},
+ year={2025},
+ eprint={2508.09976},
+ archivePrefix={arXiv},
+ primaryClass={cs.RO},
+ url={https://arxiv.org/abs/2508.09976},
+}
+```
diff --git a/phantom/configs/default.yaml b/phantom/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7de4caa5a2357a1b8466b9d079b14c4fd093e1e6
--- /dev/null
+++ b/phantom/configs/default.yaml
@@ -0,0 +1,30 @@
+# Default configuration (PHANTOM paper settings)
+debug: false
+verbose: false
+skip_existing: false
+n_processes: 1
+data_root_dir: "../data/raw_data/"
+processed_data_root_dir: "../data/processed_data/"
+demo_name: ""
+
+# Processing settings
+mode: ["bbox"] # Default processing mode - must be one of: bbox, hand2d, hand3d, hand_segmentation, arm_segmentation, action, smoothing, hand_inpaint, robot_inpaint, all
+demo_num: null # Process specific demo number (null = process all)
+
+# Additional settings
+debug_cameras: []
+
+# PHANTOM paper configuration (default)
+input_resolution: 1080
+output_resolution: 240
+robot: "Panda"
+gripper: "Robotiq85"
+square: true
+epic: false
+bimanual_setup: "single_arm"
+target_hand: "left"
+constrained_hand: true
+depth_for_overlay: true
+render: false
+camera_intrinsics: "camera/camera_intrinsics_HD1080.json"
+camera_extrinsics: "camera/camera_extrinsics.json"
diff --git a/phantom/configs/epic.yaml b/phantom/configs/epic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cbcd86f3122a51cdd79f52027fdabe579755abc1
--- /dev/null
+++ b/phantom/configs/epic.yaml
@@ -0,0 +1,31 @@
+# Default configuration (PHANTOM paper settings)
+debug: false
+verbose: false
+skip_existing: false
+n_processes: 1
+data_root_dir: "../data/raw_data/"
+processed_data_root_dir: "../data/processed_data/"
+demo_name: ""
+
+# Processing settings
+mode: ["bbox"] # Default processing mode
+demo_num: null # Process specific demo number (null = process all videos in the root folder)
+
+# Additional settings
+debug_cameras: [] # Add other robomimic cameras like sideview, etc. Warning: this significantly slows down the processing time
+
+
+# EPIC-KITCHENS configuration override
+input_resolution: 256
+output_resolution: 256
+robot: "Kinova3"
+gripper: "Robotiq85"
+square: false
+epic: true
+bimanual_setup: "shoulders"
+target_hand: "both"
+constrained_hand: false
+depth_for_overlay: false
+render: false
+camera_intrinsics: "camera/camera_intrinsics_epic.json"
+camera_extrinsics: "camera/camera_extrinsics_ego_bimanual_shoulders.json"
diff --git a/phantom/configs/sam2_hiera_l.yaml b/phantom/configs/sam2_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1092802b1d24be6fedf78939f45b0d021d4ec560
--- /dev/null
+++ b/phantom/configs/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/data/__init__.py b/phantom/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/docs/teaser_masquerade.png b/phantom/docs/teaser_masquerade.png
new file mode 100644
index 0000000000000000000000000000000000000000..821d8082dbe9ffbb0a2a2a21e3584fa204932b7e
--- /dev/null
+++ b/phantom/docs/teaser_masquerade.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f0f5355b51b44f98b8aced3b5c41255d3e9a04b0810a4d9b616c67e1ba05b9c
+size 1278978
diff --git a/phantom/docs/teaser_phantom.png b/phantom/docs/teaser_phantom.png
new file mode 100644
index 0000000000000000000000000000000000000000..fdef26e45b26c5a990313f7aa2c73374c7edba34
--- /dev/null
+++ b/phantom/docs/teaser_phantom.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a79506ef23efac9c85af0805ca5e23ec59a6a90e0de7bc475cfde94bd793f9c0
+size 3089124
diff --git a/phantom/install.sh b/phantom/install.sh
new file mode 100755
index 0000000000000000000000000000000000000000..f98f34492a025c1bd46fe9a46d6bce65ca2ed12f
--- /dev/null
+++ b/phantom/install.sh
@@ -0,0 +1,67 @@
+eval "$(conda shell.bash hook)"
+# ######################## Phantom Env ###############################
+conda create -n phantom python=3.10 -y
+conda activate phantom
+conda install nvidia/label/cuda-12.1.0::cuda-toolkit -c nvidia/label/cuda-12.1.0 -y
+
+# Install SAM2
+cd submodules/sam2
+pip install -v -e ".[notebooks]"
+cd ../..
+
+# Install Hamer
+cd submodules/phantom-hamer
+pip install -e .\[all\]
+pip install -v -e third-party/ViTPose
+wget https://www.cs.utexas.edu/~pavlakos/hamer/data/hamer_demo_data.tar.gz
+tar --warning=no-unknown-keyword --exclude=".*" -xvf hamer_demo_data.tar.gz
+cd ../..
+
+# Install mmcv
+pip install --index-url https://download.pytorch.org/whl/cu121 torch==2.1.0 torchvision==0.16.0
+pip install mmcv==1.3.9
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html
+pip install numpy==1.26.4
+
+# Install phantom-robosuite
+cd submodules/phantom-robosuite
+pip install -e .
+cd ../..
+
+# Install phantom-robomimic
+cd submodules/phantom-robomimic
+pip install -e .
+cd ../..
+
+# Install additional packages
+pip install joblib mediapy open3d pandas
+pip install transformers==4.42.4
+pip install PyOpenGL==3.1.4
+pip install Rtree
+pip install git+https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes.git
+pip install protobuf==3.20.0
+pip install hydra-core==1.3.2
+pip install omegaconf==2.3.0
+
+# Download E2FGVI weights
+cd submodules/phantom-E2FGVI/E2FGVI/release_model/
+pip install gdown
+gdown --fuzzy https://drive.google.com/file/d/10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3/view?usp=sharing
+cd ../..
+
+# Install phantom-E2FGVI
+pip install -e .
+cd ../..
+
+# Install phantom
+pip install -e .
+
+# Download sample data
+cd data/raw
+wget https://download.cs.stanford.edu/juno/phantom/pick_and_place.zip
+unzip pick_and_place.zip
+rm pick_and_place.zip
+wget https://download.cs.stanford.edu/juno/phantom/epic.zip
+unzip epic.zip
+rm epic.zip
+cd ../..
diff --git a/phantom/phantom/__init__.py b/phantom/phantom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/phantom/camera/__init__.py b/phantom/phantom/camera/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/phantom/camera/camera_extrinsics.json b/phantom/phantom/camera/camera_extrinsics.json
new file mode 100644
index 0000000000000000000000000000000000000000..002325929b83dabaae9357cf13f684994f5a200a
--- /dev/null
+++ b/phantom/phantom/camera/camera_extrinsics.json
@@ -0,0 +1,42 @@
+[
+ {
+ "camera_base_ori": [
+ [
+ 0.9842690634302423,
+ -0.053375086066005106,
+ 0.1684206369825258
+ ],
+ [
+ -0.1763762231197722,
+ -0.35235905397979306,
+ 0.9190944048336218
+ ],
+ [
+ 0.010287793357058851,
+ -0.934341584895969,
+ -0.3562302121408726
+ ]
+ ],
+ "camera_base_ori_rotvec": [
+ -1.930138005212092,
+ 0.16467696378244215,
+ -0.12809137765065973
+ ],
+ "camera_base_pos": [
+ 0.3407932803063093,
+ -0.40868423448040403,
+ 0.39911982578151795
+ ],
+ "camera_base_quat": [
+ 0.8204965462375373,
+ -0.07000374049084156,
+ 0.054451304871138306,
+ -0.564729979129313
+ ],
+ "p_marker_ee": [
+ -0.01874144739551215,
+ 0.029611448317719172,
+ -0.013687685723932594
+ ]
+ }
+]
\ No newline at end of file
diff --git a/phantom/phantom/camera/camera_extrinsics_ego_bimanual_shoulders.json b/phantom/phantom/camera/camera_extrinsics_ego_bimanual_shoulders.json
new file mode 100644
index 0000000000000000000000000000000000000000..88a9a82435a8d7b9b3f3be32671ff1c1fa8f1573
--- /dev/null
+++ b/phantom/phantom/camera/camera_extrinsics_ego_bimanual_shoulders.json
@@ -0,0 +1,52 @@
+[
+ {
+ "num_marker_seen": 114,
+ "stage2_retry": 11,
+ "pixel_error": 2.1157278874907863,
+ "proj_func": "hand_marker_proj_world_camera",
+ "intrinsics": {
+ "fx": 731.4708862304688,
+ "fy": 731.4708862304688,
+ "ppx": 646.266357421875,
+ "ppy": 355.9967956542969
+ },
+ "camera_base_ori": [
+ [
+ -0.7220417114840215,
+ 0.37764981440725887,
+ 0.579686453658689
+ ],
+ [
+ 0.020370475586732495,
+ 0.8491206965938227,
+ -0.527805917303316
+ ],
+ [
+ -0.6915495720493177,
+ -0.3692893991088662,
+ -0.6207934673498243
+ ]
+ ],
+ "camera_base_ori_rotvec": [
+ 0.2877344548443808,
+ 2.3075097094104504,
+ -0.6485227972051454
+ ],
+ "camera_base_pos": [
+ -0.5123627783256401,
+ -0.11387480700266536,
+ 0.3151264229148423
+ ],
+ "p_marker_ee": [
+ -0.041990731174163416,
+ -0.02636865486252487,
+ -0.01442948433864288
+ ],
+ "camera_base_quat": [
+ 0.11139014686225811,
+ 0.8933022830245745,
+ -0.25106152012025673,
+ 0.35576871621882866
+ ]
+ }
+]
\ No newline at end of file
diff --git a/phantom/phantom/camera/camera_intrinsics_HD1080.json b/phantom/phantom/camera/camera_intrinsics_HD1080.json
new file mode 100644
index 0000000000000000000000000000000000000000..ff52d7fe21d3d1b5bc1b078e3ac4e5ba0292f327
--- /dev/null
+++ b/phantom/phantom/camera/camera_intrinsics_HD1080.json
@@ -0,0 +1,48 @@
+{
+ "left": {
+ "fx": 1057.7322998046875,
+ "fy": 1057.7322998046875,
+ "cx": 972.5150756835938,
+ "cy": 552.568359375,
+ "disto": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "v_fov": 54.09259796142578,
+ "h_fov": 84.45639038085938,
+ "d_fov": 92.32276916503906
+ },
+ "right": {
+ "fx": 1057.7322998046875,
+ "fy": 1057.7322998046875,
+ "cx": 972.5150756835938,
+ "cy": 552.568359375,
+ "disto": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "v_fov": 54.09259796142578,
+ "h_fov": 84.45639038085938,
+ "d_fov": 92.32276916503906
+ }
+}
\ No newline at end of file
diff --git a/phantom/phantom/camera/camera_intrinsics_epic.json b/phantom/phantom/camera/camera_intrinsics_epic.json
new file mode 100644
index 0000000000000000000000000000000000000000..29986434a940212bca3f49df336bfeb3520a839a
--- /dev/null
+++ b/phantom/phantom/camera/camera_intrinsics_epic.json
@@ -0,0 +1,48 @@
+{
+ "left": {
+ "fx": 248.7892127911359,
+ "fy": 248.7892127911359,
+ "cx": 228,
+ "cy": 128,
+ "disto": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "v_fov": 54.6,
+ "h_fov": 83.21271514892578,
+ "d_fov": 91.07240295410156
+ },
+ "right": {
+ "fx": 248.7892127911359,
+ "fy": 248.7892127911359,
+ "cx": 228,
+ "cy": 128,
+ "disto": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "v_fov": 54.6,
+ "h_fov": 83.21271514892578,
+ "d_fov": 91.07240295410156
+ }
+}
\ No newline at end of file
diff --git a/phantom/phantom/detectors/detector_detectron2.py b/phantom/phantom/detectors/detector_detectron2.py
new file mode 100644
index 0000000000000000000000000000000000000000..608dd61d900a476a51ddc0285afd503ffa753047
--- /dev/null
+++ b/phantom/phantom/detectors/detector_detectron2.py
@@ -0,0 +1,121 @@
+"""
+Wrapper around detectron2 for object detection
+"""
+import os
+import numpy as np
+from pathlib import Path
+from typing import Tuple
+import cv2
+import logging
+import mediapy as media
+import requests
+import hamer # type: ignore
+from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy # type: ignore
+from detectron2.config import LazyConfig # type: ignore
+
+logger = logging.getLogger(__name__)
+
+def download_detectron_ckpt(root_dir: str, ckpt_path: str) -> None:
+ url = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
+ save_path = Path(root_dir, ckpt_path)
+ save_path.parent.mkdir(exist_ok=True, parents=True)
+ response = requests.get(url, stream=True)
+ if response.status_code == 200:
+ with open(save_path, "wb") as file:
+ for chunk in response.iter_content(chunk_size=8192):
+ file.write(chunk)
+ logger.info(f"File downloaded successfully and saved to {save_path}")
+ else:
+ logger.info(f"Failed to download the file. Status code: {response.status_code}")
+
+
+class DetectorDetectron2:
+ def __init__(self, root_dir: str):
+ cfg_path = (Path(hamer.__file__).parent / "configs" / "cascade_mask_rcnn_vitdet_h_75ep.py")
+ detectron2_cfg = LazyConfig.load(str(cfg_path))
+
+ detectron2_cfg.train.init_checkpoint = os.path.join(
+ root_dir, "_DATA/detectron_ckpts/model_final_f05665.pkl"
+ )
+ if not os.path.exists(detectron2_cfg.train.init_checkpoint):
+ download_detectron_ckpt(
+ root_dir, "_DATA/detectron_ckpts/model_final_f05665.pkl"
+ )
+ for predictor in detectron2_cfg.model.roi_heads.box_predictors:
+ predictor.test_score_thresh = 0.25
+ self.detectron2 = DefaultPredictor_Lazy(detectron2_cfg)
+
+ def get_bboxes(self, img: np.ndarray, visualize: bool=False,
+ visualize_wait: bool=True) -> Tuple[np.ndarray, np.ndarray]:
+ """ Get bounding boxes and scores for the detected hand in the image """
+ det_out = self.detectron2(img)
+
+ det_instances = det_out["instances"]
+ valid_idx = (det_instances.pred_classes == 0) & (det_instances.scores > 0.5)
+ pred_bboxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
+ pred_scores = det_instances.scores[valid_idx].cpu().numpy()
+
+ if visualize:
+ img_rgb = img.copy()
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ for bbox, score in zip(pred_bboxes, pred_scores):
+ cv2.rectangle(
+ img_bgr,
+ (int(bbox[0]), int(bbox[1])),
+ (int(bbox[2]), int(bbox[3])),
+ (0, 255, 0),
+ 2,
+ )
+ cv2.putText(img_bgr,
+ f"{score:.4f}",
+ (int(bbox[0]), int(bbox[1])),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 1,
+ (0, 255, 0),
+ 2,
+ cv2.LINE_AA)
+
+ cv2.imshow(f"Detected bounding boxes", img_bgr)
+ if visualize_wait:
+ cv2.waitKey(0)
+ else:
+ cv2.waitKey(1)
+
+ return pred_bboxes, pred_scores
+
+ def get_best_bbox(self, img: np.ndarray, visualize: bool=False,
+ visualize_wait: bool=True) -> Tuple[np.ndarray, float]:
+ """ Get the best bounding box and score for the detected hand in the image """
+ bboxes, scores = self.get_bboxes(img)
+ if len(bboxes) == 0:
+ logger.info("No bbox found with Detectron")
+ return np.array([]), 0
+ best_idx = scores.argmax()
+ best_bbox, best_score = bboxes[best_idx], scores[best_idx]
+
+ if visualize:
+ img_rgb = img.copy()
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.rectangle(
+ img_bgr,
+ (int(best_bbox[0]), int(best_bbox[1])),
+ (int(best_bbox[2]), int(best_bbox[3])),
+ (0, 255, 0),
+ 2,
+ )
+ cv2.putText(img_bgr,
+ f"{best_score:.4f}",
+ (int(best_bbox[0]), int(best_bbox[1])),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 1,
+ (0, 255, 0),
+ 2,
+ cv2.LINE_AA)
+
+ cv2.imshow(f"Best detected bounding box", img_bgr)
+ if visualize_wait:
+ cv2.waitKey(0)
+ else:
+ cv2.waitKey(1)
+
+ return best_bbox, best_score
\ No newline at end of file
diff --git a/phantom/phantom/detectors/detector_dino.py b/phantom/phantom/detectors/detector_dino.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ad8fa9162d5545ebd9f9e6bacf434e56a24e37a
--- /dev/null
+++ b/phantom/phantom/detectors/detector_dino.py
@@ -0,0 +1,108 @@
+"""
+Wrapper around DINO-V2 for object detection
+"""
+from typing import Sequence, Tuple, Optional
+import numpy as np
+from transformers import pipeline # type: ignore
+from PIL import Image
+import cv2
+import logging
+
+from phantom.utils.image_utils import DetectionResult
+
+logger = logging.getLogger(__name__)
+
+class DetectorDino:
+ def __init__(self, detector_id: str):
+ self.detector = pipeline(
+ model=detector_id,
+ task="zero-shot-object-detection",
+ device="cuda",
+ batch_size=4,
+ )
+
+ def get_bboxes(self, frame: np.ndarray, object_name: str, threshold: float = 0.4,
+ visualize: bool = False, pause_visualization: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Detect objects in a frame and return their bounding boxes and confidence scores.
+
+ Args:
+ frame: Input image as numpy array in RGB format
+ object_name: Target object category to detect
+ threshold: Confidence threshold for detection (0.0-1.0)
+ visualize: If True, displays detection results visually
+ pause_visualization: If True, waits for key press when visualizing
+
+ Returns:
+ Tuple of (bounding_boxes, confidence_scores) as numpy arrays
+ Empty arrays if no objects detected
+ """
+ img_pil = Image.fromarray(frame)
+ labels = [f"{object_name}."]
+ results = self.detector(img_pil, candidate_labels=labels, threshold=threshold)
+ results = [DetectionResult.from_dict(result) for result in results]
+ if not results:
+ return np.array([]), np.array([])
+ bboxes = np.array([np.array(result.box.xyxy) for result in results])
+ scores = np.array([result.score for result in results])
+
+ if visualize:
+ img_rgb = frame.copy()
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ for bbox, score in zip(bboxes, scores):
+ cv2.rectangle(
+ img_bgr,
+ (int(bbox[0]), int(bbox[1])),
+ (int(bbox[2]), int(bbox[3])),
+ (0, 255, 0),
+ 2,
+ )
+ cv2.putText(img_bgr,
+ f"{score:.4f}",
+ (int(bbox[0]), int(bbox[1])),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 1,
+ (0, 255, 0),
+ 2,
+ cv2.LINE_AA)
+ cv2.imshow("Detection", img_bgr)
+ if pause_visualization:
+ cv2.waitKey(0)
+ else:
+ cv2.waitKey(1)
+ return bboxes, scores
+
+
+ def get_best_bbox(self, frame: np.ndarray, object_name: str, threshold: float = 0.4,
+ visualize: bool = False, pause_visualization: bool = True) -> Optional[np.ndarray]:
+ bboxes, scores = self.get_bboxes(frame, object_name, threshold)
+ if len(bboxes) == 0:
+ return None
+ best_idx = np.array(scores).argmax()
+ best_bbox, best_score = bboxes[best_idx], scores[best_idx]
+
+ if visualize:
+ img_rgb = frame.copy()
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.rectangle(
+ img_bgr,
+ (best_bbox[0], best_bbox[1]),
+ (best_bbox[2], best_bbox[3]),
+ (0, 255, 0),
+ 2,
+ )
+ cv2.putText(img_bgr,
+ f"{best_score:.4f}",
+ (int(best_bbox[0]), int(best_bbox[1])),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 1,
+ (0, 255, 0),
+ 2,
+ cv2.LINE_AA)
+ cv2.imshow("Detection", img_bgr)
+ if pause_visualization:
+ cv2.waitKey(0)
+ else:
+ cv2.waitKey(1)
+ return best_bbox
+
\ No newline at end of file
diff --git a/phantom/phantom/detectors/detector_hamer.py b/phantom/phantom/detectors/detector_hamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc6f5143255df1a1b9f6086854266d86a53278a9
--- /dev/null
+++ b/phantom/phantom/detectors/detector_hamer.py
@@ -0,0 +1,447 @@
+"""
+Wrapper around HaMeR for hand pose estimation
+"""
+import os
+import logging
+import numpy as np
+from pathlib import Path
+from typing import Optional, Tuple
+
+import cv2
+import torch
+from hamer.utils import recursive_to # type: ignore
+import matplotlib.pyplot as plt
+
+from hamer.models import HAMER, DEFAULT_CHECKPOINT # type: ignore
+import sys
+import os
+# Add the phantom-hamer directory to Python path for vitpose_model import
+hamer_path = os.path.join(os.path.dirname(__file__), '..', '..', 'submodules', 'phantom-hamer')
+if hamer_path not in sys.path:
+ sys.path.insert(0, hamer_path)
+from vitpose_model import ViTPoseModel # type: ignore
+from hamer.datasets.vitdet_dataset import ViTDetDataset # type: ignore
+from hamer.utils.renderer import cam_crop_to_full # type: ignore
+from hamer.utils.geometry import perspective_projection # type: ignore
+from hamer.configs import get_config # type: ignore
+from yacs.config import CfgNode as CN # type: ignore
+
+from phantom.utils.data_utils import get_parent_folder_of_package
+
+logger = logging.getLogger(__name__)
+
+THUMB_VERTEX = 756
+INDEX_FINGER_VERTEX = 350
+
+class DetectorHamer:
+ """
+ Detector using the HaMeR model for 3D hand pose estimation.
+
+ The detection pipeline consists of:
+ - Initial hand detection using general object detectors
+ - Hand type classification (left/right) using ViTPose
+ - 3D pose estimation using HaMeR
+ - MANO parameters estimation for mesh reconstruction
+
+ Dependencies:
+ - HaMeR model for 3D pose estimation
+ - ViTPose for keypoint detection
+ - DINO and Detectron2 for initial hand detection
+ """
+ def __init__(self):
+ root_dir = get_parent_folder_of_package("hamer")
+ checkpoint_path = Path(root_dir, DEFAULT_CHECKPOINT)
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ self.rescale_factor = 2.0 # Factor for padding the box
+ self.batch_size = 1 # Batch size for inference
+
+ self.model, self.model_cfg = self.load_hamer_model(checkpoint_path, root_dir)
+ self.model.to(self.device)
+ self.model.eval()
+
+ root_dir = "../submodules/phantom-hamer/"
+ vit_dir = os.path.join(root_dir, "third-party/ViTPose/")
+ self.cpm = ViTPoseModel(device=self.device, root_dir=root_dir, vit_dir=vit_dir)
+
+ self.faces_right = self.model.mano.faces
+ self.faces_left = self.faces_right[:,[0,2,1]]
+
+ def detect_hand_keypoints(self,
+ img: np.ndarray,
+ hand_side: str,
+ visualize: bool=False,
+ visualize_3d: bool=False,
+ pause_visualization: bool=True,
+ bboxes: Optional[np.ndarray]=None,
+ is_right: Optional[np.ndarray]=None,
+ kpts_2d_only: Optional[bool]=False,
+ camera_params: Optional[dict]=None) -> Optional[dict]:
+ """
+ Detect hand keypoints in the input image.
+
+ The method performs the following steps:
+ 1. Detect hand bounding boxes using object detectors
+ 2. Optionally refine boxes using ViTPose to determine hand type (left/right)
+ 3. Run HaMeR model to estimate 3D hand pose
+ 4. Project 3D keypoints back to 2D for visualization
+
+ Args:
+ img: Input RGB image as numpy array
+ hand_side: Target hand side to detect (left or right)
+ visualize: If True, displays detection results in a window
+ visualize_3d: If True, shows 3D visualization of keypoints and mesh
+ pause_visualization: If True, waits for key press when visualizing
+ bboxes: Bounding boxes of the hands
+ is_right: Whether the hand is right
+ kpts_2d_only: If True, only cares about 2D keypoints, i.e., use default
+ focal length in HaMeR instead of real camera intrinsics
+ camera_params: Optional camera intrinsics (fx, fy, cx, cy)
+
+ Returns:
+ Dictionary containing:
+ - annotated_img: Image with keypoints drawn
+ - success: Whether detection was successful (21 keypoints found)
+ - kpts_3d: 3D keypoints in camera space
+ - kpts_2d: 2D keypoints projected onto image
+ - verts: 3D mesh vertices
+ - T_cam_pred: Camera transformation matrix
+ - Various camera parameters and MANO pose parameters
+ """
+ if not kpts_2d_only:
+ scaled_focal_length, camera_center = self.get_image_params(img, camera_params)
+ else:
+ scaled_focal_length, camera_center = self.get_image_params(img, camera_params=None)
+
+
+ dataset = ViTDetDataset(self.model_cfg, img, bboxes, is_right, rescale_factor=self.rescale_factor)
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
+
+ list_2d_kpts, list_3d_kpts, list_verts = [], [], []
+ T_cam_pred_all: list[torch.Tensor] = []
+ list_global_orient = []
+ kpts_2d_hamer = None
+ for batch in dataloader:
+ batch = recursive_to(batch, "cuda")
+ with torch.no_grad():
+ out = self.model(batch)
+
+ batch_T_cam_pred_all = DetectorHamer.get_all_T_cam_pred(batch, out, scaled_focal_length)
+
+ for idx in range(len(batch_T_cam_pred_all)):
+ kpts_3d = out["pred_keypoints_3d"][idx].detach().cpu().numpy() # [21, 3]
+ verts = out["pred_vertices"][idx].detach().cpu().numpy() # [778, 3]
+ is_right = batch["right"][idx].cpu().numpy()
+ global_orient = out["pred_mano_params"]["global_orient"][idx].detach().cpu().numpy()
+ hand_pose = out["pred_mano_params"]["hand_pose"][idx].detach().cpu().numpy()
+ list_global_orient.append(global_orient)
+
+ if hand_side == "left":
+ kpts_3d, verts = DetectorHamer.convert_right_hand_keypoints_to_left_hand(kpts_3d, verts)
+
+ T_cam_pred = batch_T_cam_pred_all[idx]
+
+ img_w, img_h = batch["img_size"][idx].float()
+
+ kpts_2d_hamer = DetectorHamer.project_3d_kpt_to_2d(kpts_3d, img_w, img_h, scaled_focal_length,
+ camera_center, T_cam_pred)
+
+ # Keep T_cam_pred as tensor
+ list_2d_kpts.append(kpts_2d_hamer)
+ list_3d_kpts.append(kpts_3d + T_cam_pred.cpu().numpy())
+ list_verts.append(verts + T_cam_pred.cpu().numpy())
+
+ T_cam_pred_all += batch_T_cam_pred_all
+
+ annotated_img = DetectorHamer.visualize_2d_kpt_on_img(
+ kpts_2d=list_2d_kpts[0],
+ img=img,
+ )
+
+ if visualize:
+ if bboxes is not None:
+ cv2.rectangle(annotated_img, (int(bboxes[0][0]), int(bboxes[0][1])), (int(bboxes[0][2]), int(bboxes[0][3])), (0, 255, 0), 2)
+ cv2.imshow("Annotated Image", annotated_img)
+ cv2.waitKey(0 if pause_visualization else 1)
+
+ if visualize_3d:
+ DetectorHamer.visualize_keypoints_3d(annotated_img, list_3d_kpts[0], list_verts[0])
+
+
+ return {
+ "annotated_img": annotated_img,
+ "success": len(list_2d_kpts[0]) == 21,
+ "kpts_3d": list_3d_kpts[0],
+ "kpts_2d": np.rint(list_2d_kpts[0]).astype(np.int32),
+ "verts": list_verts[0],
+ "T_cam_pred": T_cam_pred_all[0],
+ "scaled_focal_length": scaled_focal_length,
+ "camera_center": camera_center,
+ "img_w": img_w,
+ "img_h": img_h,
+ "global_orient": list_global_orient[0],
+ "hand_pose": hand_pose,
+ }
+
+ def get_image_params(self, img: np.ndarray, camera_params: Optional[dict]) -> Tuple[float, torch.Tensor]:
+ """
+ Get the scaled focal length and camera center.
+ """
+ img_w = img.shape[1]
+ img_h = img.shape[0]
+ if camera_params is not None:
+ scaled_focal_length = camera_params["fx"]
+ cx = camera_params["cx"]
+ cy = camera_params["cy"]
+ camera_center = torch.tensor([img_w-cx, img_h-cy])
+ else:
+ scaled_focal_length = (self.model_cfg.EXTRA.FOCAL_LENGTH / self.model_cfg.MODEL.IMAGE_SIZE
+ * max(img_w, img_h))
+ camera_center = torch.tensor([img_w, img_h], dtype=torch.float).reshape(1, 2) / 2.0
+ return scaled_focal_length, camera_center
+
+ @staticmethod
+ def convert_right_hand_keypoints_to_left_hand(kpts, verts):
+ """
+ Convert right hand keypoints/vertices to left hand by mirroring across the Y-Z plane.
+
+ This is done by flipping the X coordinates of both keypoints and vertices.
+ The MANO model internally uses right hand, so this conversion is needed
+ when processing left hands.
+
+ Args:
+ kpts: 3D keypoints [21, 3]
+ verts: 3D mesh vertices [778, 3]
+
+ Returns:
+ Transformed keypoints and vertices
+ """
+ kpts[:,0] = -kpts[:,0]
+ verts[:,0] = -verts[:,0]
+ return kpts, verts
+
+ @staticmethod
+ def visualize_keypoints_3d(annotated_img: np.ndarray, kpts_3d: np.ndarray, verts: np.ndarray) -> None:
+ nfingers = len(kpts_3d) - 1
+ npts_per_finger = 4
+ list_fingers = [np.vstack([kpts_3d[0], kpts_3d[i:i + npts_per_finger]]) for i in range(1, nfingers, npts_per_finger)]
+ finger_colors_bgr = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 0, 255), (0, 255, 255)]
+ finger_colors_rgb = [(color[2], color[1], color[0]) for color in finger_colors_bgr]
+ fig, axs = plt.subplots(1,2, figsize=(20, 10))
+ axs[0] = fig.add_subplot(111, projection='3d')
+ for finger_idx, finger_pts in enumerate(list_fingers):
+ for i in range(len(finger_pts) - 1):
+ color = finger_colors_rgb[finger_idx]
+ axs[0].plot(
+ [finger_pts[i][0], finger_pts[i + 1][0]],
+ [finger_pts[i][1], finger_pts[i + 1][1]],
+ [finger_pts[i][2], finger_pts[i + 1][2]],
+ color=np.array(color)/255.0,
+ )
+ axs[0].scatter(kpts_3d[:, 0], kpts_3d[:, 1], kpts_3d[:, 2])
+ axs[0].scatter(verts[:, 0], verts[:, 1], verts[:, 2])
+ annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
+ axs[1].imshow(annotated_img_rgb)
+
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.imshow(annotated_img_rgb)
+
+ plt.show()
+
+ @staticmethod
+ def get_all_T_cam_pred(batch: dict, out: dict, scaled_focal_length: float) -> torch.Tensor:
+ """
+ Get the camera transformation matrix
+ """
+ multiplier = 2 * batch["right"] - 1
+ pred_cam = out["pred_cam"]
+ pred_cam[:, 1] = multiplier * pred_cam[:, 1]
+ box_center = batch["box_center"].float()
+ box_size = batch["box_size"].float()
+ # NOTE: FOR HaMeR, they are using the img_size as (W, H)
+ W_H_shapes = batch["img_size"].float()
+
+ multiplier = 2 * batch["right"] - 1
+ T_cam_pred_all = cam_crop_to_full(
+ pred_cam, box_center, box_size, W_H_shapes, scaled_focal_length
+ )
+
+ return T_cam_pred_all
+
+ @staticmethod
+ def visualize_2d_kpt_on_img(kpts_2d: np.ndarray, img: np.ndarray) -> np.ndarray:
+ """
+ Plot 2D hand keypoints on the image with finger connections.
+
+ Each finger is drawn with a different color:
+ - Thumb: Green
+ - Index: Blue
+ - Middle: Red
+ - Ring: Magenta
+ - Pinky: Cyan
+
+ Args:
+ kpts_2d: 2D keypoints as integers [21, 2]
+ img: Input RGB image
+
+ Returns:
+ Image with keypoints and connections drawn (BGR format)
+ """
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ pts = kpts_2d.astype(np.int32)
+ nfingers = len(pts) - 1
+ npts_per_finger = 4
+ list_fingers = [np.vstack([pts[0], pts[i:i + npts_per_finger]]) for i in range(1, nfingers, npts_per_finger)]
+ finger_colors = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 0, 255), (0, 255, 255)]
+ thickness = 5 if img_bgr.shape[0] > 1000 else 2
+ for finger_idx, finger_pts in enumerate(list_fingers):
+ for i in range(len(finger_pts) - 1):
+ color = finger_colors[finger_idx]
+ cv2.line(
+ img_bgr,
+ tuple(finger_pts[i]),
+ tuple(finger_pts[i + 1]),
+ color,
+ thickness=thickness,
+ )
+
+ cv2.line(img_bgr, [1787, 1522], [1656,1400], (255,0,0), thickness=thickness)
+
+ for pt in pts:
+ cv2.circle(img_bgr, (pt[0], pt[1]), radius=thickness, color=(0,0,0), thickness=thickness-1)
+
+ return img_bgr
+
+
+ @staticmethod
+ def project_3d_kpt_to_2d(kpts_3d: torch.Tensor, img_w: int, img_h: int, scaled_focal_length: float,
+ camera_center: torch.Tensor, T_cam: Optional[torch.Tensor] = None,) -> np.ndarray:
+ """
+ Project 3D keypoints to 2D image coordinates using perspective projection.
+ """
+ batch_size = 1
+
+ rotation = torch.eye(3).unsqueeze(0)
+ assert T_cam is not None
+
+ T_cam = T_cam.cpu()
+ kpts_3d = torch.tensor(kpts_3d).cpu()
+
+ T_cam = T_cam.clone().cuda()
+ kpts_3d = kpts_3d.clone().cuda()
+ rotation = rotation.cuda()
+
+ scaled_focal_length_full = torch.tensor([scaled_focal_length, scaled_focal_length]).reshape(1, 2)
+
+ # IMPORTANT: The perspective_projection function assumes T_cam has not been added to kpts_3d already!
+ kpts_2d = perspective_projection(
+ kpts_3d.reshape(batch_size, -1, 3),
+ rotation=rotation.repeat(batch_size, 1, 1),
+ translation=T_cam.reshape(batch_size, -1),
+ focal_length=scaled_focal_length_full.repeat(batch_size, 1),
+ camera_center=camera_center.repeat(batch_size, 1),
+ ).reshape(batch_size, -1, 2)
+ kpts_2d = kpts_2d[0].cpu().numpy()
+
+ return np.rint(kpts_2d).astype(np.int32)
+
+ @staticmethod
+ def annotate_bboxes_on_img(img: np.ndarray, debug_bboxes: dict) -> np.ndarray:
+ """
+ Annotate bounding boxes on the image.
+
+ :param img: Input image (numpy array)
+ :param debug_bboxes: Dictionary containing different sets of bounding boxes and optional scores
+ :return: Annotated image
+ """
+ color_dict = {
+ "dino_bboxes": (0, 255, 0),
+ "det_bboxes": (0, 0, 255),
+ "refined_bboxes": (255, 0, 0),
+ "filtered_bboxes": (255, 255, 0),
+ }
+ corner_dict = {
+ "dino_bboxes": "top_left",
+ "det_bboxes": "top_right",
+ "refined_bboxes": "bottom_left",
+ "filtered_bboxes": "bottom_right",
+ }
+
+ def draw_bbox_and_label(bbox, label, color, label_pos, include_label=True):
+ """ Helper function to draw the bounding box and add label """
+ cv2.rectangle(
+ img,
+ (int(bbox[0]), int(bbox[1])),
+ (int(bbox[2]), int(bbox[3])),
+ color,
+ 2,
+ )
+ if include_label:
+ cv2.putText(
+ img, label, label_pos,
+ cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA
+ )
+
+ label_pos_dict = {
+ "top_left": lambda bbox: (int(bbox[0]), int(bbox[1]) - 10),
+ "bottom_right": lambda bbox: (int(bbox[2]) - 150, int(bbox[3]) - 10),
+ "top_right": lambda bbox: (int(bbox[2]) - 150, int(bbox[1]) - 10),
+ "bottom_left": lambda bbox: (int(bbox[0]), int(bbox[3]) - 10),
+ }
+
+ for key, value in debug_bboxes.items():
+ # Unpack bboxes and scores
+ if key in ["dino_bboxes", "det_bboxes"]:
+ bboxes, scores = value
+ else:
+ bboxes = value
+ scores = [None] * len(bboxes)
+
+ color = color_dict.get(key, (0, 0, 0))
+ label_pos_fn = label_pos_dict[corner_dict.get(key, "top_left")]
+
+ # Draw each bounding box and its label
+ for idx, bbox in enumerate(bboxes):
+ score_text = f" {scores[idx]:.3f}" if scores[idx] is not None else ""
+ label = key.split("_")[0] + score_text
+
+ # Draw bounding box and label on the image
+ label_pos = label_pos_fn(bbox)
+ if key in ["dino_bboxes", "det_bboxes"] or idx == 0:
+ draw_bbox_and_label(bbox, label, color, label_pos)
+ return img
+
+
+ @staticmethod
+ def load_hamer_model(checkpoint_path: str, root_dir: Optional[str] = None) -> Tuple[HAMER, CN]:
+ """
+ Load the HaMeR model from the checkpoint path.
+ """
+ model_cfg_path = str(Path(checkpoint_path).parent.parent / "model_config.yaml")
+ model_cfg = get_config(model_cfg_path, update_cachedir=True)
+ # update model and params path
+ if root_dir:
+ model_cfg.defrost()
+ model_cfg.MANO.DATA_DIR = os.path.join(root_dir, model_cfg.MANO.DATA_DIR)
+ model_cfg.MANO.MODEL_PATH = os.path.join(root_dir, model_cfg.MANO.MODEL_PATH.replace("./", ""))
+ model_cfg.MANO.MEAN_PARAMS = os.path.join(root_dir, model_cfg.MANO.MEAN_PARAMS.replace("./", ""))
+ model_cfg.freeze()
+
+ # Override some config values, to crop bbox correctly
+ if (model_cfg.MODEL.BACKBONE.TYPE == "vit") and ("BBOX_SHAPE" not in model_cfg.MODEL):
+ model_cfg.defrost()
+ assert (
+ model_cfg.MODEL.IMAGE_SIZE == 256
+ ), f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
+ model_cfg.MODEL.BBOX_SHAPE = [192, 256]
+ model_cfg.freeze()
+
+ # Update config to be compatible with demo
+ if "PRETRAINED_WEIGHTS" in model_cfg.MODEL.BACKBONE:
+ model_cfg.defrost()
+ model_cfg.MODEL.BACKBONE.pop("PRETRAINED_WEIGHTS")
+ model_cfg.freeze()
+
+ model = HAMER.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
+ return model, model_cfg
diff --git a/phantom/phantom/detectors/detector_sam2.py b/phantom/phantom/detectors/detector_sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe180fe7652d4bf3739c5a5f9aa90054da356945
--- /dev/null
+++ b/phantom/phantom/detectors/detector_sam2.py
@@ -0,0 +1,240 @@
+"""
+Wrapper around SAM2 for object segmentation
+"""
+import numpy as np
+import pdb
+import os
+import logging
+import requests
+from typing import Tuple, Optional
+from pathlib import Path
+import matplotlib.pyplot as plt
+from matplotlib.axes import Axes
+import cv2
+from PIL import Image
+import torch
+from sam2.build_sam import build_sam2 # type: ignore
+from sam2.sam2_image_predictor import SAM2ImagePredictor # type: ignore
+from sam2.build_sam import build_sam2_video_predictor # type: ignore
+
+logger = logging.getLogger(__name__)
+
+def download_sam2_ckpt(ckpt_path: str) -> None:
+ url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
+ save_path = Path(ckpt_path)
+ save_path.parent.mkdir(exist_ok=True, parents=True)
+ response = requests.get(url, stream=True)
+ if response.status_code == 200:
+ with open(save_path, "wb") as file:
+ for chunk in response.iter_content(chunk_size=8192):
+ file.write(chunk)
+ logger.info(f"File downloaded successfully and saved to {save_path}")
+ else:
+ logger.info(f"Failed to download the file. Status code: {response.status_code}")
+
+class DetectorSam2:
+ """
+ A detector that uses the SAM2 model for object segmentation in images and videos.
+ """
+ def __init__(self):
+ checkpoint = "../submodules/sam2/checkpoints/sam2_hiera_large.pt"
+ model_cfg = "sam2_hiera_l.yaml"
+
+ if not os.path.exists(checkpoint):
+ download_sam2_ckpt(checkpoint)
+ self.device = "cuda"
+
+ self.video_predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=self.device)
+
+ def segment_video(self, video_dir: Path, bbox: np.ndarray, points: np.ndarray,
+ indices: int, reverse: bool=False, output_bboxes: Optional[np.ndarray]=None):
+ """
+ Segment an object across video frames using SAM2's video tracking capabilities.
+
+ Parameters:
+ video_dir: Directory containing video frames as image files
+ bbox: Bounding box coordinates [x0, y0, x1, y1] for the object to track
+ points: Point(s) on the object to track
+ start_idx: Frame index to start tracking from
+
+ Returns:
+ video_segments: Dictionary mapping frame indices to segmentation masks
+ list_annotated_imgs: Array of frames with the segmented object masked out
+ """
+ frame_names = os.listdir(video_dir)
+ frame_names = sorted(frame_names)
+ with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
+ state = self.video_predictor.init_state(video_path=str(video_dir))
+ self.video_predictor.reset_state(state)
+
+ for point, idx in zip(points, indices):
+ try:
+ if bbox is None or np.all(bbox) == 0:
+ self.video_predictor.add_new_points_or_box(
+ state,
+ frame_idx=int(idx),
+ obj_id=0,
+ points=np.array(point),
+ labels=np.ones(len(point)),
+ )
+ else:
+ self.video_predictor.add_new_points_or_box(
+ state,
+ frame_idx=int(idx),
+ obj_id=0,
+ box=np.array(bbox),
+ points=np.array(point),
+ labels=np.ones(len(point)),
+ )
+ except Exception as e:
+ print("Error in adding new points or box:", e)
+ pdb.set_trace()
+
+ video_segments = {}
+ for (
+ out_frame_idx,
+ out_obj_ids,
+ out_mask_logits,
+ ) in self.video_predictor.propagate_in_video(state, reverse=reverse):
+ video_segments[out_frame_idx] = {
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+ for i, out_obj_id in enumerate(out_obj_ids)
+ }
+
+ frame_indices = list(video_segments.keys())
+ frame_indices.sort()
+ list_annotated_imgs = {}
+ for out_frame_idx in frame_indices:
+ img = Image.open(os.path.join(video_dir, frame_names[out_frame_idx]))
+ img_arr = np.array(img)
+ mask = video_segments[out_frame_idx][0]
+ if output_bboxes is not None:
+ # Crop the mask to the bounding box
+ output_bbox = output_bboxes[out_frame_idx].astype(np.int32)
+ if output_bbox.sum() > 0:
+ bbox_mask = np.zeros_like(mask)
+ bbox_mask = self._crop_mask_to_bbox(mask, output_bbox)
+ mask = mask * bbox_mask
+ img_arr[mask[0]] = (0, 0, 0)
+ list_annotated_imgs[out_frame_idx] = img_arr
+
+ if output_bboxes is not None:
+ for out_frame_idx in frame_indices:
+ output_bbox = output_bboxes[out_frame_idx].astype(np.int32)
+ mask = video_segments[out_frame_idx][0]
+ mask_ori = mask.copy()
+ if output_bbox.sum() > 0:
+ bbox_mask = np.zeros_like(mask)
+ bbox_mask = self._crop_mask_to_bbox(mask, output_bbox)
+ mask = mask * bbox_mask
+ video_segments[out_frame_idx] = {
+ 0: mask
+ }
+
+ # Fix gpu memory leak
+ torch.cuda.empty_cache()
+
+ return video_segments, list_annotated_imgs
+
+ def _crop_mask_to_bbox(self, mask: np.ndarray, bbox: np.ndarray) -> np.ndarray:
+ """
+ Crop a mask to a bounding box.
+ """
+ margin = 20
+ bbox = bbox.astype(np.int32)
+ x0, y0, x1, y1 = bbox
+ x0 = max(0, x0 - margin)
+ x1 = min(mask.shape[2], x1 + margin)
+ y0 = max(0, y0 - margin)
+ y1 = min(mask.shape[1], y1 + margin)
+ bbox_mask = np.zeros_like(mask)
+ bbox_mask[:, y0:y1, x0:x1] = 1
+ return bbox_mask
+
+ def segment_video_from_mask(self, video_dir: str, mask: np.ndarray, frame_idx: int, reverse=False):
+ """
+ Propagate a segmentation mask through video frames (forward or backward).
+
+ Parameters:
+ video_dir: Directory containing video frames
+ mask: Initial segmentation mask to propagate
+ frame_idx: Frame index where the mask is defined
+ reverse: If True, propagate backward in time; if False, propagate forward
+
+ Returns:
+ frame_indices: List of frame indices where masks were generated
+ video_segments: Dictionary mapping frame indices to segmentation masks
+ """
+ with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
+ state = self.video_predictor.init_state(video_path=video_dir)
+ self.video_predictor.reset_state(state)
+
+ self.video_predictor.add_new_mask(state, frame_idx, 0, mask)
+
+ video_segments = {}
+ mask_prob = {}
+ for (
+ out_frame_idx,
+ out_obj_ids,
+ out_mask_logits,
+ ) in self.video_predictor.propagate_in_video(state, reverse=reverse):
+ mask_prob[out_frame_idx] = torch.mean(torch.sigmoid(out_mask_logits))
+ video_segments[out_frame_idx] = {
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+ for i, out_obj_id in enumerate(out_obj_ids)
+ }
+
+ frame_indices = list(video_segments.keys())
+ frame_indices.sort()
+ return frame_indices, video_segments
+
+ @staticmethod
+ def show_mask(mask: np.ndarray, ax: Axes, random_color: bool=False, borders: bool = True) -> None:
+ if random_color:
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+ else:
+ color = np.array([30/255, 144/255, 255/255, 0.6])
+ h, w = mask.shape[-2:]
+ mask = mask.astype(np.uint8)
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+ if borders:
+ contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ # Try to smooth contours
+ contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
+ mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+ ax.imshow(mask_image)
+
+
+ @staticmethod
+ def show_masks(image: np.ndarray, masks: np.ndarray, scores: np.ndarray, point_coords: Optional[np.ndarray]=None,
+ box_coords: Optional[np.ndarray]=None, input_labels: Optional[np.ndarray]=None, borders: bool=True) -> None:
+ n_masks = len(masks)
+ fig, axs = plt.subplots(1, n_masks, figsize=(10*n_masks, 10))
+ for i, (mask, score) in enumerate(zip(masks, scores)):
+ axs[i].imshow(image)
+ DetectorSam2.show_mask(mask, axs[i], borders=borders)
+ if point_coords is not None:
+ assert input_labels is not None
+ DetectorSam2.show_points(point_coords, input_labels, axs[i])
+ if box_coords is not None:
+ DetectorSam2.show_box(box_coords, axs[i])
+ if len(scores) > 1:
+ axs[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+ axs[i].axis('off')
+ plt.show()
+
+ @staticmethod
+ def show_box(box: np.ndarray, ax: Axes) -> None:
+ x0, y0 = box[0], box[1]
+ w, h = box[2] - box[0], box[3] - box[1]
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+
+ @staticmethod
+ def show_points(coords: np.ndarray, labels: np.ndarray, ax: Axes, marker_size: int=375) -> None:
+ pos_points = coords[labels==1]
+ neg_points = coords[labels==0]
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
+ s=marker_size, edgecolor='white', linewidth=1.25)
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
+ s=marker_size, edgecolor='white', linewidth=1.25)
diff --git a/phantom/phantom/hand.py b/phantom/phantom/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..a13fe0ecbff77249d2251315106d6160c0de5d20
--- /dev/null
+++ b/phantom/phantom/hand.py
@@ -0,0 +1,805 @@
+"""
+Hand Model Module
+
+This module provides hand modeling for action processors. It converts detected hand
+keypoints into kinematic models that can be used for robot control
+
+Key Components:
+- HandModel: Base class for unconstrained hand kinematic modeling
+- PhysicallyConstrainedHandModel: Extended class with constrained joint and velocity limits
+- Grasp point and orientation calculation for robot end-effector control
+
+The hand model follows the MediaPipe hand landmark convention with 21 keypoints:
+- Wrist (1 point)
+- Thumb (4 points: MCP, PIP, DIP, TIP)
+- Index finger (4 points: MCP, PIP, DIP, TIP)
+- Middle finger (4 points: MCP, PIP, DIP, TIP)
+- Ring finger (4 points: MCP, PIP, DIP, TIP)
+- Pinky finger (4 points: MCP, PIP, DIP, TIP)
+
+Coordinate System:
+- All calculations performed in robot coordinate frame
+- Grasp orientations aligned with robot end-effector conventions
+- Joint rotations represented as rotation matrices and Euler angles
+"""
+
+from typing import Optional, List, Dict, Tuple, Union, Any
+import numpy as np
+import pdb
+import torch
+from scipy.spatial.transform import Rotation
+import logging
+
+from phantom.utils.transform_utils import *
+logger = logging.getLogger(__name__)
+
+class HandModel:
+ """
+ Base class for hand kinematic modeling and trajectory analysis.
+
+ This class provides a kinematic representation of a human hand using 21 keypoints
+ from hand pose estimation. It calculates joint rotations, tracks hand motion over
+ time, and computes grasp points and orientations suitable for robot control.
+
+ The kinematic structure follows a tree topology with the wrist as the root,
+ and each finger as a separate chain. Joint rotations are calculated relative
+ to parent joints using vector alignment methods.
+
+ Key Features:
+ - 21-point hand keypoint processing
+ - Joint rotation calculation using vector alignment
+ - Grasp point computation from thumb-index / thumb-middle finger positioning
+ - End-effector orientation calculation for robot control
+
+ Attributes:
+ robot_name (str): Name of the target robot for coordinate frame alignment
+ kinematic_tree (List[Tuple[int, int]]): Parent-child relationships for hand joints
+ joint_to_neighbors_mapping (Dict[int, Tuple[int, int, int]]): Mapping of joints to their neighbors
+ vertex_positions (List[np.ndarray]): Time series of hand keypoint positions
+ joint_rotations (List[List[np.ndarray]]): Time series of joint rotation matrices
+ grasp_points (List[np.ndarray]): Time series of computed grasp points
+ grasp_oris (List[np.ndarray]): Time series of grasp orientation matrices
+ timestamps (List[float]): Time stamps for each frame
+ num_joints (int): Total number of joints in the hand model
+ joint_rotations_xyz (List[List[np.ndarray]]): Time series of Euler angle representations
+ """
+ def __init__(self, robot_name: str) -> None:
+ """
+ Initialize the hand model with kinematic structure.
+
+ Args:
+ robot_name: Name of the target robot for coordinate alignment
+ """
+ self.robot_name: str = robot_name
+
+ # Define the kinematic tree structure for hand joints
+ # Format: (joint_index, parent_index) where -1 indicates root (wrist)
+ self.kinematic_tree: List[Tuple[int, int]] = [
+ (0, -1), # wrist base (root of the kinematic tree)
+
+ # Thumb chain (4 joints)
+ (1, 0), # thumb mcp
+ (2, 1), # thumb pip
+ (3, 2), # thumb dip
+ (4, 3), # thumb tip
+
+ # Index finger chain (4 joints)
+ (5, 0), # index mcp
+ (6, 5), # index pip
+ (7, 6), # index dip
+ (8, 7), # index tip
+
+ # Middle finger chain (4 joints)
+ (9, 0), # middle mcp
+ (10, 9), # middle pip
+ (11, 10), # middle dip
+ (12, 11), # middle tip
+
+ # Ring finger chain (4 joints)
+ (13, 0), # ring mcp
+ (14, 13), # ring pip
+ (15, 14), # ring dip
+ (16, 15), # ring tip
+
+ # Pinky finger chain (4 joints)
+ (17, 0), # pinky mcp
+ (18, 17), # pinky pip
+ (19, 18), # pinky dip
+ (20, 19), # pinky tip
+ ]
+
+ # Mapping from joint index to (current_vertex, child_vertex, parent_vertex)
+ # This defines the local coordinate system for each joint rotation calculation
+ self.joint_to_neighbors_mapping: Dict[int, Tuple[int, int, int]] = {
+ # Thumb joint mappings
+ 0: (0, 1, -1), # wrist to thumb mcp (no parent)
+ 1: (1, 2, 0), # thumb mcp to pip (parent: wrist)
+ 2: (2, 3, 1), # thumb pip to dip (parent: thumb mcp)
+ 3: (3, 4, 2), # thumb dip to tip (parent: thumb pip)
+
+ # Index finger joint mappings
+ 4: (0, 5, -1), # wrist to index mcp (no parent)
+ 5: (5, 6, 0), # index mcp to pip (parent: wrist)
+ 6: (6, 7, 5), # index pip to dip (parent: index mcp)
+ 7: (7, 8, 6), # index dip to tip (parent: index pip)
+
+ # Middle finger joint mappings
+ 8: (0, 9, -1), # wrist to middle mcp (no parent)
+ 9: (9, 10, 0), # middle mcp to pip (parent: wrist)
+ 10: (10, 11, 9), # middle pip to dip (parent: middle mcp)
+ 11: (11, 12, 10),# middle dip to tip (parent: middle pip)
+
+ # Ring finger joint mappings
+ 12: (0, 13, -1), # wrist to ring mcp (no parent)
+ 13: (13, 14, 0),# ring mcp to pip (parent: wrist)
+ 14: (14, 15, 13),# ring pip to dip (parent: ring mcp)
+ 15: (15, 16, 14),# ring dip to tip (parent: ring pip)
+
+ # Pinky finger joint mappings
+ 16: (0, 17, -1), # wrist to pinky mcp (no parent)
+ 17: (17, 18, 0),# pinky mcp to pip (parent: wrist)
+ 18: (18, 19, 17),# pinky pip to dip (parent: pinky mcp)
+ 19: (19, 20, 18),# pinky dip to tip (parent: pinky pip)
+ }
+
+ self.num_joints: int = len(self.joint_to_neighbors_mapping)
+
+ # Time series data storage
+ self.vertex_positions: List[np.ndarray] = [] # List of (21, 3) arrays for each timestep
+ self.joint_rotations: List[List[np.ndarray]] = [] # List of rotation matrices for each joint
+ self.joint_rotations_xyz: List[List[np.ndarray]] = [] # List of Euler angle representations
+ self.grasp_points: List[np.ndarray] = [] # List of computed grasp points (3D positions)
+ self.grasp_oris: List[np.ndarray] = [] # List of grasp orientation matrices (3x3)
+ self.timestamps: List[float] = [] # List of timestamps for temporal analysis
+
+ def calculate_joint_rotation(self, current_pos: np.ndarray, child_pos: np.ndarray, parent_pos: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Calculate the rotation matrix for a single joint using vector alignment.
+
+ This method computes the rotation that aligns the previous direction vector
+ with the current direction vector. For root joints (no parent), it uses
+ a default upward direction as the reference.
+
+ Args:
+ current_pos: 3D position of the current joint
+ child_pos: 3D position of the child joint
+ parent_pos: 3D position of the parent joint
+
+ Returns:
+ Tuple containing:
+ - rotation_matrix: 3x3 rotation matrix
+ - euler_angles: Rotation as XYZ Euler angles
+ """
+ # Calculate current direction vector (current -> child)
+ current_dir = child_pos - current_pos
+ current_norm = np.linalg.norm(current_dir)
+ if current_norm < 1e-10:
+ return np.eye(3), np.array([0,0,0])
+ current_dir /= current_norm
+
+ # Calculate previous direction vector (parent -> current, or default up)
+ prev_dir = np.array([0.0, 0.0, 1.0]) if parent_pos is None else current_pos - parent_pos
+ prev_norm = np.linalg.norm(prev_dir)
+ if prev_norm < 1e-10:
+ return np.eye(3), np.array([0,0,0])
+ prev_dir /= prev_norm
+
+ # Check if vectors are already aligned (no rotation needed)
+ if np.abs((np.abs(np.dot(current_dir, prev_dir)) - 1)) < 1e-8:
+ return np.eye(3), np.array([0,0,0])
+
+ # Calculate rotation that aligns prev_dir with current_dir
+ rotation, _ = Rotation.align_vectors([current_dir], [prev_dir])
+ return rotation.as_matrix(), rotation.as_euler('xyz')
+
+ def calculate_frame_rotations(self, vertices: np.ndarray) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """
+ Calculate rotation matrices for all joints in a single frame.
+
+ This method processes all joints in the hand and computes their rotations
+ based on the kinematic structure and current vertex positions.
+
+ Args:
+ vertices: Hand keypoints, shape (21, 3)
+
+ Returns:
+ Tuple containing:
+ - rotation_matrices: List of 3x3 rotation matrices
+ - euler_angles: List of XYZ Euler angle arrays
+ """
+ rotations, rotations_xyz = zip(*[
+ self.calculate_joint_rotation(vertices[m[0]], vertices[m[1]],
+ None if m[2] == -1 else vertices[m[2]])
+ for m in self.joint_to_neighbors_mapping.values()
+ ])
+ return list(rotations), list(rotations_xyz)
+
+ def calculate_angular_velocity(self, joint_idx: int, t1: int, t2: int) -> np.ndarray:
+ """
+ Calculate angular velocity for a specific joint between two time frames.
+
+ Angular velocity is computed as the rotation vector difference divided
+ by the time difference between frames.
+
+ Args:
+ joint_idx: Index of the joint
+ t1: Index of the first time frame
+ t2: Index of the second time frame
+
+ Returns:
+ Angular velocity vector (3,) in rad/s
+ """
+ dt = self.timestamps[t2] - self.timestamps[t1]
+ if dt == 0:
+ return np.zeros(3)
+
+ # Get rotation matrices for the two time frames
+ R1, R2 = self.joint_rotations[t1][joint_idx], self.joint_rotations[t2][joint_idx]
+
+ # Calculate relative rotation and convert to angular velocity
+ R_relative = Rotation.from_matrix(R2) * Rotation.from_matrix(R1).inv()
+ return R_relative.as_rotvec() / dt
+
+ def calculate_frame_angular_velocities(self, current_frame_idx: int) -> np.ndarray:
+ """
+ Calculate angular velocities for all joints at the current frame.
+
+ This method computes the angular velocity vectors for all joints by
+ comparing rotations with the previous frame. Returns zeros for the
+ first frame since no previous frame exists.
+
+ Args:
+ current_frame_idx: Index of the current frame. Must be > 0.
+
+ Returns:
+ Array of angular velocity vectors (shape: num_joints x 3)
+ Each row contains [wx, wy, wz] for one joint.
+ Returns zeros if current_frame_idx == 0.
+ """
+ if current_frame_idx == 0:
+ return np.zeros((self.num_joints, 3))
+
+ prev_frame_idx = current_frame_idx - 1
+
+ # Vectorized calculation for all joints
+ velocities = np.array([
+ self.calculate_angular_velocity(joint_idx, prev_frame_idx, current_frame_idx)
+ for joint_idx in range(self.num_joints)
+ ])
+
+ return velocities
+
+ def calculate_grasp_plane(self, vertices: np.ndarray) -> np.ndarray:
+ """
+ Calculate the plane that best fits through a set of hand vertices.
+
+ This method uses Singular Value Decomposition (SVD) to find the plane.
+ The plane is typically fitted through thumb and index finger points.
+
+ Args:
+ vertices: Set of 3D points to fit plane through, shape (N, 3)
+
+ Returns:
+ Plane coefficients [a, b, c, d] for ax + by + cz + d = 0
+ """
+ # Create augmented matrix with homogeneous coordinates for plane fitting
+ A = np.c_[vertices[:, 0], vertices[:, 1], vertices[:, 2], np.ones(vertices.shape[0])]
+
+ # Right-hand side is zeros for the plane equation ax + by + cz + d = 0
+ b = np.zeros(vertices.shape[0])
+
+ # Use SVD to solve the least squares problem
+ U, S, Vt = np.linalg.svd(A)
+
+ # Plane coefficients are in the last row of Vt (smallest singular value)
+ plane_coeffs = Vt[-1, :]
+
+ # Normalize coefficients for easier interpretation (unit normal vector)
+ plane_coeffs = plane_coeffs / np.linalg.norm(plane_coeffs[:3])
+
+ return plane_coeffs # [a, b, c, d]
+
+ def calculate_grasp_point(self, grasp_plane: np.ndarray, vertices: np.ndarray) -> np.ndarray:
+ """
+ Calculate the optimal grasp point for robot end-effector positioning.
+
+ The grasp point is computed as the midpoint between projected thumb tip
+ and index finger tip on the grasp plane. This provides a stable reference
+ point for robot grasping operations.
+
+ Args:
+ grasp_plane: Plane coefficients [a, b, c, d]
+ vertices: Hand keypoints, shape (21, 3)
+
+ Returns:
+ 3D grasp point coordinates
+ """
+ # Project fingertips onto the grasp plane
+ thumb_pt = project_point_to_plane(vertices[4], grasp_plane)
+ index_pt = project_point_to_plane(vertices[8], grasp_plane)
+
+ # Compute midpoint as the grasp reference
+ hand_ee_pt = np.mean([thumb_pt, index_pt], axis=0)
+ return hand_ee_pt
+
+ def add_frame(self, vertices: np.ndarray, timestamp: float, hand_detected: bool = True) -> None:
+ """
+ Add a new frame of vertex positions and calculate corresponding data.
+
+ This is the main method for processing hand data over time. It computes
+ grasp points, orientations, and stores all relevant information for
+ the current timestep.
+
+ Args:
+ vertices: Array of 21 3D vertex positions
+ timestamp: Time of the frame in seconds
+ hand_detected: Whether hand was successfully detected
+ """
+ if len(vertices) != 21:
+ raise ValueError(f"Expected 21 vertices, got {len(vertices)}")
+
+ # Handle frames without hand detection
+ if not hand_detected:
+ self.vertex_positions.append(np.zeros((21, 3)))
+ self.grasp_points.append(np.zeros(3))
+ self.grasp_oris.append(np.eye(3))
+ self.timestamps.append(timestamp)
+ return
+
+ # Extract key finger tip positions
+ thumb_tip = vertices[4]
+ index_tip = vertices[8]
+ middle_tip = vertices[12]
+
+ # Calculate grasp point as midpoint between thumb and middle finger tips
+ control_point = (thumb_tip + middle_tip) / 2
+ grasp_pt = control_point
+
+ # Calculate gripper orientation from thumb-index finger configuration
+ gripper_ori, _ = HandModel.get_gripper_orientation(thumb_tip, index_tip, vertices)
+
+ # Apply 90-degree rotation to align with robot gripper convention
+ rot_90_deg = Rotation.from_euler('Z', 90, degrees=True).as_matrix()
+ grasp_ori = gripper_ori @ rot_90_deg
+
+ # Store all frame data
+ self.vertex_positions.append(vertices)
+ self.grasp_points.append(grasp_pt)
+ self.grasp_oris.append(grasp_ori)
+ self.timestamps.append(timestamp)
+
+
+ def get_joint_data(self, joint_idx: int) -> Dict[str, Union[List[float], List[np.ndarray]]]:
+ """
+ Get all trajectory data for a specific joint across all frames.
+
+ Args:
+ joint_idx: Index of the joint
+
+ Returns:
+ Dictionary containing joint trajectory data with keys:
+ - 'timestamps': List of time stamps
+ - 'rotations': List of rotation matrices for this joint
+ """
+ return {
+ 'timestamps': self.timestamps,
+ 'rotations': [frame[joint_idx] for frame in self.joint_rotations],
+ }
+
+ @staticmethod
+ def get_parallel_plane(a: float, b: float, c: float, d: float, dist: float) -> Tuple[float, float, float, float]:
+ """
+ Calculate coefficients of a plane parallel to the given plane at specified distance.
+
+ This utility method is useful for creating offset grasp planes that account
+ for gripper thickness or provide clearance during grasping operations.
+
+ Parameters:
+ a, b, c, d: Coefficients of the original plane ax + by + cz + d = 0
+ dist: Distance between planes (positive moves in normal direction)
+
+ Returns:
+ (a, b, c, d_new) coefficients of the parallel plane
+ """
+ # Calculate the magnitude of the normal vector
+ normal_magnitude = np.sqrt(a**2 + b**2 + c**2)
+
+ # Parallel plane has same normal direction, only d changes
+ d_new = d - dist * normal_magnitude
+
+ return (a, b, c, d_new)
+
+ @staticmethod
+ def get_gripper_orientation(thumb_tip: np.ndarray, index_tip: np.ndarray, vertices: np.ndarray, grasp_plane: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Compute robot gripper orientation matrix from hand keypoints and fingertip positions.
+
+ This method calculates a coordinate frame suitable for robot gripper control
+ based on the relative positions of thumb, index finger, and wrist. The resulting
+ orientation matrix can be directly used for robot end-effector control.
+
+ Args:
+ thumb_tip: 3D position of thumb tip
+ index_tip: 3D position of index finger tip
+ vertices: All hand keypoints, shape (21, 3)
+ grasp_plane: Plane coefficients [a,b,c,d]
+
+ Returns:
+ Tuple containing:
+ - gripper_orientation: 3x3 rotation matrix
+ - z_axis: Z-axis direction vector of the gripper frame
+ """
+ # Calculate gripper opening direction (thumb to index finger)
+ gripper_direction = thumb_tip - index_tip
+
+ # Calculate gripper reference point (midpoint of fingertips)
+ midpoint = (thumb_tip + index_tip) / 2
+
+ if grasp_plane is None:
+ # Use palm geometry when no plane is provided
+ palm_axis = vertices[5] - midpoint # index MCP to midpoint
+ x_axis = gripper_direction / max(np.linalg.norm(gripper_direction), 1e-10)
+ z_axis = -palm_axis / max(np.linalg.norm(palm_axis), 1e-10)
+ else:
+ # Use grasp plane for orientation calculation
+ palm_axis = project_point_to_plane(vertices[0], grasp_plane) - project_point_to_plane(vertices[1], grasp_plane)
+ z_axis = -palm_axis / max(np.linalg.norm(palm_axis), 1e-10)
+ x_axis = np.cross(grasp_plane[:3], z_axis)
+ x_axis /= max(np.linalg.norm(x_axis), 1e-10)
+
+ # Compute y-axis
+ y_axis = np.cross(z_axis, x_axis)
+ y_axis /= max(np.linalg.norm(y_axis), 1e-10)
+
+ # Ensure orthogonality by recalculating z_axis
+ z_axis = np.cross(x_axis, y_axis)
+ z_axis /= max(np.linalg.norm(z_axis), 1e-10)
+
+ # Check orientation consistency with palm direction
+ if type(palm_axis) == torch.Tensor:
+ palm_axis = palm_axis.cpu().numpy()
+ if z_axis @ palm_axis > 0:
+ x_axis, y_axis, z_axis = -x_axis, -y_axis, -z_axis
+
+ # Construct orientation matrix
+ gripper_ori = np.column_stack([x_axis, y_axis, z_axis])
+
+ # Ensure proper handedness (right-handed coordinate system)
+ if np.linalg.det(gripper_ori) < 0:
+ x_axis = -x_axis # Flip one axis to fix handedness
+ gripper_ori = np.column_stack([x_axis, y_axis, z_axis])
+
+ # Verify determinant for debugging
+ det = np.linalg.det(gripper_ori)
+ if det < 0.9:
+ pdb.set_trace()
+
+ return gripper_ori, z_axis
+
+
+class PhysicallyConstrainedHandModel(HandModel):
+ """
+ Extended hand model with physical constraints and realistic joint limits.
+
+ This class builds upon the base HandModel by adding realistic constraints
+ that enforce physically plausible hand poses and motion. It includes:
+ - Joint angle limits based on human hand anatomy
+ - Angular velocity constraints for smooth motion
+ - Pose reconstruction with constraint enforcement
+ - Enhanced grasp point calculation with plane-based refinement
+
+ Constrained hand model is used in Phantom
+
+ Key Constraints:
+ - Anatomically correct joint limits for each finger joint
+ - Velocity limiting to prevent jerky motions
+ - Iterative pose refinement with constraint satisfaction
+ - More robust grasp plane calculation and orientation alignment
+
+ Attributes:
+ joint_limits (Dict[int, Tuple[float, ...]]): Joint angle limits for each joint in radians
+ max_angular_velocity (float): Maximum allowed angular velocity in rad/s
+ """
+ def __init__(self, robot_name: str) -> None:
+ """
+ Initialize the physically constrained hand model.
+
+ Args:
+ robot_name: Name of the target robot for coordinate alignment
+ """
+ super().__init__(robot_name)
+
+ # Define joint rotation limits (in radians) for each joint
+ # Format: (min_x, max_x, min_y, max_y, min_z, max_z) for XYZ Euler angles
+ small_angle = np.pi/40 # Small constraint for fine motor control
+
+ self.joint_limits: Dict[int, Tuple[float, float, float, float, float, float]] = {
+ # Thumb joints - more flexible due to opposable nature
+ 0: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to thumb mcp
+ 1: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # thumb mcp to pip
+ 2: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # thumb pip to dip
+ 3: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # thumb dip to tip
+
+ # Index finger joints - moderate constraints
+ 4: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to index mcp
+ 5: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # index mcp to pip
+ 6: (-small_angle, small_angle, -np.pi/8, np.pi/8, -small_angle, small_angle), # index pip to dip
+ 7: (-small_angle, small_angle, -np.pi/8, np.pi/8, -small_angle, small_angle), # index dip to tip
+
+ # Middle finger joints - tighter constraints for stability
+ 8: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to middle mcp
+ 9: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # middle mcp to pip
+ 10: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # middle pip to dip
+ 11: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # middle dip to tip
+
+ # Ring finger joints - similar to middle finger
+ 12: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to ring mcp
+ 13: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # ring mcp to pip
+ 14: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # ring pip to dip
+ 15: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # ring dip to tip
+
+ # Pinky finger joints - most constrained due to size
+ 16: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to pinky mcp
+ 17: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # pinky mcp to pip
+ 18: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # pinky pip to dip
+ 19: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # pinky dip to tip
+ }
+
+ # Maximum angular velocity constraint (2π rad/s = 360°/s)
+ self.max_angular_velocity: float = np.pi * 2
+
+ def reconstruct_vertices(self, input_vertices: np.ndarray, rotations: List[np.ndarray]) -> np.ndarray:
+ """
+ Reconstruct vertex positions from base vertex and constrained rotations.
+
+ This method applies the kinematic chain to reconstruct hand vertex positions
+ while respecting the calculated bone lengths from the input vertices.
+ This ensures consistent hand proportions while applying constraints.
+
+ Args:
+ input_vertices: Original vertex positions, shape (21, 3)
+ rotations: List of constrained rotation matrices
+
+ Returns:
+ Reconstructed vertex positions, shape (21, 3)
+ """
+ vertices = np.zeros((21, 3))
+ vertices[0] = input_vertices[0] # Wrist position remains fixed
+
+ # Calculate bone lengths from original vertices to maintain proportions
+ bone_lengths: Dict[Tuple[int, int], float] = {}
+ min_bone_length = 1e-6 # Minimum length to avoid numerical issues
+
+ # Extract bone lengths from the kinematic chain
+ for current in range(self.num_joints):
+ mapping = self.joint_to_neighbors_mapping[current]
+ current_vertex = mapping[0]
+ child_vertex = mapping[1]
+ parent_vertex = mapping[2]
+
+ # Calculate bone length for current->child connection
+ if child_vertex != -2:
+ length = np.linalg.norm(input_vertices[child_vertex] - input_vertices[current_vertex])
+ bone_lengths[(current_vertex, child_vertex)] = max(length, min_bone_length)
+
+ # Reconstruct positions following the kinematic chain
+ for current in range(self.num_joints):
+ mapping = self.joint_to_neighbors_mapping[current]
+ current_vertex = mapping[0]
+ child_vertex = mapping[1]
+ parent_vertex = mapping[2]
+
+ if child_vertex == -2:
+ continue
+
+ # Get positions and rotation for this joint
+ parent_pos = vertices[parent_vertex]
+ current_pos = vertices[current_vertex]
+ rotation = rotations[current]
+
+ # Determine reference direction for rotation application
+ if parent_vertex == -1:
+ # Root joints use upward direction as reference
+ prev_dir = np.array([0, 0, 1])
+ else:
+ # Use direction from parent to current vertex
+ prev_dir = vertices[current_vertex] - vertices[parent_vertex]
+ prev_dir = prev_dir / np.linalg.norm(prev_dir)
+
+ # Apply rotation to get new direction
+ current_dir = rotation @ prev_dir
+
+ # Position child vertex using calculated bone length
+ bone_length = bone_lengths[(current_vertex, child_vertex)]
+ vertices[child_vertex] = current_pos + current_dir * bone_length
+
+ return vertices
+
+ def constrain_rotation(self, rotation_matrix: np.ndarray, joint_idx: int) -> np.ndarray:
+ """
+ Apply joint angle constraints to a rotation matrix.
+
+ This method converts the rotation to Euler angles, clips them to the
+ joint limits, and converts back to a rotation matrix. This ensures
+ all joint angles remain within anatomically realistic ranges.
+
+ Args:
+ rotation_matrix: 3x3 rotation matrix to constrain
+ joint_idx: Index of the joint for limit lookup
+
+ Returns:
+ Constrained 3x3 rotation matrix
+ """
+ try:
+ # Convert rotation matrix to Euler angles
+ rot = Rotation.from_matrix(rotation_matrix)
+ euler = rot.as_euler('xyz')
+
+ # Get joint limits for this joint
+ limits = self.joint_limits[joint_idx]
+
+ # Clip Euler angles to the specified limits
+ constrained_euler = np.clip(euler,
+ [limits[0], limits[2], limits[4]], # min limits
+ [limits[1], limits[3], limits[5]]) # max limits
+
+ # Convert back to rotation matrix if any clipping occurred
+ if not np.allclose(euler, constrained_euler):
+ return Rotation.from_euler('xyz', constrained_euler).as_matrix()
+ return rotation_matrix
+
+ except ValueError:
+ logger.error("Error constraining rotation")
+ # Return identity matrix if rotation is invalid
+ return np.eye(3)
+
+ def constrain_velocity(self, velocity: np.ndarray) -> np.ndarray:
+ """
+ Apply angular velocity constraints to limit motion speed.
+
+ This method ensures that joint angular velocities don't exceed the
+ maximum allowed velocity, preventing jerky or unrealistic motions.
+
+ Args:
+ velocity: Angular velocity vector to constrain
+
+ Returns:
+ Constrained angular velocity vector
+ """
+ velocity_magnitude = np.linalg.norm(velocity)
+ if velocity_magnitude > self.max_angular_velocity:
+ # Scale velocity to maximum while preserving direction
+ return velocity * (self.max_angular_velocity / velocity_magnitude)
+ return velocity
+
+ def add_frame(self, vertices: np.ndarray, timestamp: float, finger_pts: Any) -> None:
+ """
+ Add a new frame with physical constraints applied.
+
+ This method extends the base add_frame functionality by applying
+ joint limits, velocity constraints, and enhanced grasp calculations.
+ The result is a more realistic and stable hand model suitable for
+ robot control applications.
+
+ Args:
+ vertices: Hand keypoints, shape (21, 3)
+ timestamp: Time of the frame in seconds
+ finger_pts: Additional finger point data (currently unused)
+ """
+ # Calculate initial rotations from raw vertex positions
+ rotations, rotations_xyz = self.calculate_frame_rotations(vertices)
+
+ # Apply joint angle constraints to all rotations
+ constrained_rotations: List[np.ndarray] = []
+ for joint_idx, rotation in enumerate(rotations):
+ constrained_rot = self.constrain_rotation(rotation, joint_idx)
+ constrained_rotations.append(constrained_rot)
+
+ # Apply velocity constraints if this is not the first frame
+ if len(self.timestamps) > 0:
+ dt = timestamp - self.timestamps[-1]
+ for joint_idx in range(self.num_joints):
+ # Calculate angular velocity for this joint
+ prev_rot = Rotation.from_matrix(self.joint_rotations[-1][joint_idx])
+ curr_rot = Rotation.from_matrix(constrained_rotations[joint_idx])
+ rel_rot = curr_rot * prev_rot.inv()
+ velocity = rel_rot.as_rotvec() / dt
+
+ # Apply velocity constraint if needed
+ if np.linalg.norm(velocity) > self.max_angular_velocity:
+ # Constrain velocity and reconstruct rotation
+ constrained_velocity = self.constrain_velocity(velocity)
+ delta_rot = Rotation.from_rotvec(constrained_velocity * dt)
+ new_rot = delta_rot * prev_rot
+ constrained_rotations[joint_idx] = new_rot.as_matrix()
+
+ # Reconstruct vertices with constrained rotations
+ constrained_vertices = self.reconstruct_vertices(vertices, constrained_rotations)
+
+ # Extract key points for grasp calculation
+ thumb_tip = constrained_vertices[4]
+ index_tip = constrained_vertices[8]
+
+ # Calculate grasp plane using thumb and index finger regions
+ grasp_plane = self.calculate_grasp_plane(constrained_vertices[3:9])
+
+ # Organize fingers for direction analysis
+ n_fingers = len(constrained_vertices) - 1
+ npts_per_finger = 4
+ list_fingers = [np.vstack([constrained_vertices[0], constrained_vertices[i:i + npts_per_finger]])
+ for i in range(1, n_fingers, npts_per_finger)]
+
+ # Calculate finger direction vector for plane orientation
+ dir_vec = list_fingers[1][1] - list_fingers[-1][1] # index to pinky MCP
+ dir_vec = dir_vec / np.linalg.norm(dir_vec)
+
+ # Ensure consistent plane orientation (normal pointing away from palm)
+ if np.dot(dir_vec, grasp_plane[:3]) > 0:
+ grasp_plane = -grasp_plane
+
+ # Create slightly offset plane for grasp point calculation
+ shifted_grasp_plane = self.get_parallel_plane(*grasp_plane, 0.01)
+ grasp_pt = self.calculate_grasp_point(shifted_grasp_plane, constrained_vertices)
+
+ # Calculate gripper orientation using the grasp plane
+ gripper_ori, _ = HandModel.get_gripper_orientation(thumb_tip, index_tip, constrained_vertices, grasp_plane)
+
+ # Apply coordinate frame transformations for robot compatibility
+ rot_90_deg = Rotation.from_euler('Z', 90, degrees=True).as_matrix()
+ grasp_ori = gripper_ori @ rot_90_deg
+
+ # Apply pitch adjustment
+ angle = -np.pi/18 * 1.0 # -10 degrees
+ grasp_ori = Rotation.from_rotvec(angle * np.array([1, 0, 0])).apply(grasp_ori)
+
+ # Offset grasp point along gripper Z-axis for clearance
+ unit_vectors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+ transformed_vectors = unit_vectors @ grasp_ori.T
+ grasp_pt = grasp_pt - transformed_vectors[2] * 0.015 # 1.5cm offset
+
+ # Store all frame data
+ self.joint_rotations.append(constrained_rotations)
+ self.joint_rotations_xyz.append(rotations_xyz)
+ self.vertex_positions.append(constrained_vertices)
+ self.grasp_points.append(grasp_pt)
+ self.grasp_oris.append(grasp_ori)
+ self.timestamps.append(timestamp)
+
+
+def get_list_finger_pts_from_skeleton(skeleton_pts: np.ndarray) -> Dict[str, np.ndarray]:
+ """
+ Organize hand skeleton points into finger-specific groups.
+
+ This utility function takes the 21-point hand skeleton and organizes
+ it into a dictionary with separate arrays for each finger. This makes
+ it easier to perform finger-specific calculations and analysis.
+
+ Args:
+ skeleton_pts: Hand skeleton points, shape (21, 3)
+ Points are ordered as: wrist, thumb(4), index(4), middle(4), ring(4), pinky(4)
+
+ Returns:
+ Dictionary with finger names as keys and point arrays as values:
+ - "thumb": Wrist + 4 thumb points, shape (5, 3)
+ - "index": Wrist + 4 index points, shape (5, 3)
+ - "middle": Wrist + 4 middle points, shape (5, 3)
+ - "ring": Wrist + 4 ring points, shape (5, 3)
+ - "pinky": Wrist + 4 pinky points, shape (5, 3)
+ """
+ n_fingers = len(skeleton_pts) - 1 # Exclude wrist point
+ npts_per_finger = 4 # MCP, PIP, DIP, TIP for each finger
+
+ # Create finger arrays by combining wrist with each finger's points
+ list_fingers = [
+ np.vstack([skeleton_pts[0], skeleton_pts[i : i + npts_per_finger]])
+ for i in range(1, n_fingers, npts_per_finger)
+ ]
+
+ # Return organized finger dictionary
+ return {
+ "thumb": list_fingers[0],
+ "index": list_fingers[1],
+ "middle": list_fingers[2],
+ "ring": list_fingers[3],
+ "pinky": list_fingers[4]
+ }
\ No newline at end of file
diff --git a/phantom/phantom/process_data.py b/phantom/phantom/process_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e276043b093e33ae09f8503578c803bee873eb
--- /dev/null
+++ b/phantom/phantom/process_data.py
@@ -0,0 +1,243 @@
+import logging
+from enum import Enum
+from tqdm import tqdm
+from joblib import Parallel, delayed # type: ignore
+import hydra
+from omegaconf import DictConfig
+
+from phantom.processors.base_processor import BaseProcessor
+
+logging.basicConfig(level=logging.WARNING, format="%(name)s - %(levelname)s - %(message)s")
+
+class ProcessingMode(Enum):
+ """Enumeration of valid processing modes."""
+ BBOX = "bbox"
+ HAND2D = "hand2d"
+ HAND3D = "hand3d"
+ HAND_SEGMENTATION = "hand_segmentation"
+ ARM_SEGMENTATION = "arm_segmentation"
+ ACTION = "action"
+ SMOOTHING = "smoothing"
+ HAND_INPAINT = "hand_inpaint"
+ ROBOT_INPAINT = "robot_inpaint"
+ ALL = "all"
+
+PROCESSING_ORDER = [
+ "bbox",
+ "hand2d",
+ "arm_segmentation",
+ "hand_segmentation",
+ "hand3d",
+ "action",
+ "smoothing",
+ "hand_inpaint",
+ "robot_inpaint",
+]
+
+PROCESSING_ORDER_EPIC = [
+ "bbox",
+ "hand2d",
+ "arm_segmentation",
+ "action",
+ "smoothing",
+ "hand_inpaint",
+ "robot_inpaint",
+]
+
+def process_one_demo(data_sub_folder: str, cfg: DictConfig, processor_classes: dict) -> None:
+ # Choose processing order based on epic flag
+ processing_order = PROCESSING_ORDER_EPIC if cfg.epic else PROCESSING_ORDER
+
+ # Handle both string and list modes
+ if isinstance(cfg.mode, str):
+ # Handle comma-separated string format
+ if ',' in cfg.mode:
+ selected_modes = []
+ for mode in cfg.mode.split(','):
+ mode = mode.strip()
+ if mode == "all":
+ selected_modes.extend(processing_order)
+ elif mode in processing_order:
+ selected_modes.append(mode)
+ else:
+ selected_modes = [m for m in processing_order if m in cfg.mode or "all" in cfg.mode]
+ else:
+ # For list of modes, use the order provided by user
+ selected_modes = []
+ for mode in cfg.mode:
+ if mode == "all":
+ selected_modes.extend(processing_order)
+ elif mode in processing_order:
+ selected_modes.append(mode)
+
+ for mode in selected_modes:
+ print(f"----------------- {mode.upper()} PROCESSOR -----------------")
+ processor_cls = processor_classes[mode]
+ processor = processor_cls(cfg)
+ try:
+ processor.process_one_demo(data_sub_folder)
+ except Exception as e:
+ print(f"Error in {mode} processing: {e}")
+ if cfg.debug:
+ raise
+
+def process_all_demos(cfg: DictConfig, processor_classes: dict) -> None:
+ # Choose processing order based on epic flag
+ processing_order = PROCESSING_ORDER_EPIC if cfg.epic else PROCESSING_ORDER
+
+ # Handle both string and list modes
+ if isinstance(cfg.mode, str):
+ # Handle comma-separated string format
+ if ',' in cfg.mode:
+ selected_modes = []
+ for mode in cfg.mode.split(','):
+ mode = mode.strip()
+ if mode == "all":
+ selected_modes.extend(processing_order)
+ elif mode in processing_order:
+ selected_modes.append(mode)
+ else:
+ selected_modes = [m for m in processing_order if m in cfg.mode or "all" in cfg.mode]
+ else:
+ # For list of modes, use the order provided by user
+ selected_modes = []
+ for mode in cfg.mode:
+ if mode == "all":
+ selected_modes.extend(processing_order)
+ elif mode in processing_order:
+ selected_modes.append(mode)
+
+ base_processor = BaseProcessor(cfg)
+ all_data_folders = base_processor.all_data_folders.copy()
+ for mode in selected_modes:
+ print(f"----------------- {mode.upper()} PROCESSOR -----------------")
+ processor_cls = processor_classes[mode]
+ processor = processor_cls(cfg)
+ for data_sub_folder in tqdm(all_data_folders):
+ try:
+ processor.process_one_demo(data_sub_folder)
+ except Exception as e:
+ print(f"Error in {mode} processing: {e}")
+ if cfg.debug:
+ raise
+
+def process_all_demos_parallel(cfg: DictConfig, processor_classes: dict) -> None:
+ # Choose processing order based on epic flag
+ processing_order = PROCESSING_ORDER_EPIC if cfg.epic else PROCESSING_ORDER
+
+ # Handle both string and list modes
+ if isinstance(cfg.mode, str):
+ # Handle comma-separated string format
+ if ',' in cfg.mode:
+ selected_modes = []
+ for mode in cfg.mode.split(','):
+ mode = mode.strip()
+ if mode == "all":
+ selected_modes.extend(processing_order)
+ elif mode in processing_order:
+ selected_modes.append(mode)
+ else:
+ selected_modes = [m for m in processing_order if m in cfg.mode or "all" in cfg.mode]
+ else:
+ # For list of modes, use the order provided by user
+ selected_modes = []
+ for mode in cfg.mode:
+ if mode == "all":
+ selected_modes.extend(processing_order)
+ elif mode in processing_order:
+ selected_modes.append(mode)
+
+ base_processor = BaseProcessor(cfg)
+ all_data_folders = base_processor.all_data_folders.copy()
+ for mode in selected_modes:
+ print(f"----------------- {mode.upper()} PROCESSOR -----------------")
+ processor_cls = processor_classes[mode]
+ processor = processor_cls(cfg)
+ Parallel(n_jobs=cfg.n_processes)(
+ delayed(processor.process_one_demo)(data_sub_folder) for data_sub_folder in all_data_folders
+ )
+
+def get_processor_classes(cfg: DictConfig) -> dict:
+ """Initialize the processor classes"""
+ from phantom.processors.bbox_processor import BBoxProcessor
+ from phantom.processors.segmentation_processor import HandSegmentationProcessor, ArmSegmentationProcessor
+ from phantom.processors.hand_processor import Hand2DProcessor, Hand3DProcessor
+ from phantom.processors.action_processor import ActionProcessor
+ from phantom.processors.smoothing_processor import SmoothingProcessor
+ from phantom.processors.robotinpaint_processor import RobotInpaintProcessor
+ from phantom.processors.handinpaint_processor import HandInpaintProcessor
+
+ return {
+ "bbox": BBoxProcessor,
+ "hand2d": Hand2DProcessor,
+ "hand3d": Hand3DProcessor,
+ "hand_segmentation": HandSegmentationProcessor,
+ "arm_segmentation": ArmSegmentationProcessor,
+ "action": ActionProcessor,
+ "smoothing": SmoothingProcessor,
+ "robot_inpaint": RobotInpaintProcessor,
+ "hand_inpaint": HandInpaintProcessor,
+ }
+
+def validate_mode(cfg: DictConfig) -> None:
+ """
+ Validate that the mode parameter contains only valid processing modes.
+
+ Args:
+ cfg: Configuration object containing mode parameter
+
+ Raises:
+ ValueError: If mode contains invalid options
+ """
+ if isinstance(cfg.mode, str):
+ # Handle comma-separated string format
+ if ',' in cfg.mode:
+ modes = [mode.strip() for mode in cfg.mode.split(',')]
+ else:
+ modes = [cfg.mode]
+ else:
+ modes = cfg.mode
+
+ # Get valid modes from enum
+ valid_modes = {mode.value for mode in ProcessingMode}
+ invalid_modes = [mode for mode in modes if mode not in valid_modes]
+
+ if invalid_modes:
+ valid_mode_list = [mode.value for mode in ProcessingMode]
+ raise ValueError(
+ f"Invalid mode(s): {invalid_modes}. "
+ f"Valid modes are: {valid_mode_list}"
+ )
+
+def main(cfg: DictConfig):
+ # Validate mode parameter
+ validate_mode(cfg)
+
+ # Get processor classes
+ processor_classes = get_processor_classes(cfg)
+
+ if cfg.n_processes > 1:
+ process_all_demos_parallel(cfg, processor_classes)
+ elif cfg.demo_num is not None:
+ process_one_demo(cfg.demo_num, cfg, processor_classes)
+ else:
+ process_all_demos(cfg, processor_classes)
+
+@hydra.main(version_base=None, config_path="../configs", config_name="default")
+def hydra_main(cfg: DictConfig):
+ """
+ Main entry point using Hydra configuration.
+
+ Example usage:
+ - Process all demos with bbox: python process_data.py mode=bbox
+ - Process single demo: python process_data.py mode=bbox demo_num=0
+ - Use EPIC dataset: python process_data.py dataset=epic mode=bbox
+ - Parallel processing: python process_data.py mode=bbox n_processes=4
+ - Process multiple modes sequentially: python process_data.py mode=bbox,hand3d
+ - Process with custom order: python process_data.py mode=hand3d,bbox,action
+ - Process with bracket notation (use quotes): python process_data.py "mode=[bbox,hand3d]"
+ """
+ main(cfg)
+
+if __name__ == "__main__":
+ hydra_main()
diff --git a/phantom/phantom/processors/__init__.py b/phantom/phantom/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/phantom/processors/action_processor.py b/phantom/phantom/processors/action_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b7e029d30cfc9d97c142b327b84b1101b22756b
--- /dev/null
+++ b/phantom/phantom/processors/action_processor.py
@@ -0,0 +1,478 @@
+"""
+Action Processor Module
+
+This module processes hand motion capture data and converts it into robot-executable actions.
+It handles both single-arm and bimanual robotic setups, converting detected hand keypoints
+into end-effector positions, orientations, and gripper widths that can be used for robot control.
+
+Key Features:
+- Converts hand keypoints from camera frame to robot frame
+- Supports both unconstrained and physically constrained hand models
+- Handles missing hand detections with interpolation
+- Processes bimanual data with union-based frame selection
+- Generates neutral poses when no hand data is available
+
+The processor follows this pipeline:
+1. Load hand sequence data (keypoints, detection flags)
+2. Convert keypoints to robot coordinate frame
+3. Apply hand model constraints (optional)
+4. Extract end-effector poses and gripper states
+5. Refine actions to handle missing detections
+6. Save processed actions for robot execution
+"""
+
+import os
+import numpy as np
+from typing import Tuple, Optional
+from dataclasses import dataclass
+import logging
+from scipy.spatial.transform import Rotation
+
+from phantom.processors.base_processor import BaseProcessor
+from phantom.processors.phantom_data import HandSequence
+from phantom.processors.paths import Paths
+from phantom.hand import HandModel, PhysicallyConstrainedHandModel, get_list_finger_pts_from_skeleton
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class EEActions:
+ """
+ Container for bimanual end-effector action data.
+
+ This dataclass holds the processed robot actions for a sequence of timesteps,
+ including 3D positions, 3D orientations, and gripper opening widths.
+
+ Attributes:
+ ee_pts (np.ndarray): End-effector positions, shape (N, 3) in robot frame coordinates
+ ee_oris (np.ndarray): End-effector orientations as rotation matrices, shape (N, 3, 3)
+ ee_widths (np.ndarray): Gripper opening widths in meters, shape (N,)
+ """
+ ee_pts: np.ndarray # End-effector positions (N, 3)
+ ee_oris: np.ndarray # End-effector orientations (N, 3, 3) as rotation matrices
+ ee_widths: np.ndarray # Gripper widths (N,)
+
+class ActionProcessor(BaseProcessor):
+ """
+ Processor for converting hand motion capture data into robot-executable actions.
+
+ This class handles the complete pipeline from raw hand keypoints to refined robot actions.
+ It supports both single-arm and bimanual robotic setups, with intelligent handling of
+ missing hand detections and physically realistic constraints.
+
+ The processor can operate in different modes:
+ - Single arm: Processes only left or right hand data
+ - Bimanual: Processes both hands with union-based frame selection
+
+ Key processing steps:
+ 1. Load hand sequences with 3D keypoints and detection flags
+ 2. Transform keypoints from camera frame to robot frame
+ 3. Fit hand model (optionally with physical constraints)
+ 4. Extract end-effector poses and gripper states
+ 5. Refine actions using last-valid-value interpolation
+ 6. Generate neutral poses for undetected periods
+
+ Attributes:
+ dt (float): Time delta between frames (1/15 seconds for 15Hz processing)
+ bimanual_setup (str): Setup type ("single_arm", "shoulders", etc.)
+ target_hand (str): Which hand to process in single-arm mode ("left"/"right")
+ constrained_hand (bool): Whether to use physically constrained hand model
+ T_cam2robot (np.ndarray): 4x4 transformation matrix from camera to robot frame
+ """
+ def __init__(self, args):
+ # Set processing frequency to 15Hz
+ self.dt = 1/15
+ super().__init__(args)
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process a single demonstration recording into robot actions.
+
+ This is the main entry point for processing one demo. It handles both
+ single-arm and bimanual processing modes, loading the raw hand data,
+ converting it to robot actions, and saving the results.
+
+ Args:
+ data_sub_folder (str): Path to the folder containing this demo's data
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+ paths = self.get_paths(save_folder)
+
+ # Load hand sequence data for both hands
+ left_sequence, right_sequence = self._load_sequences(paths)
+
+ # Handle single-arm processing mode
+ if self.bimanual_setup == "single_arm":
+ self._process_single_arm(left_sequence, right_sequence, paths)
+ else:
+ self._process_bimanual(left_sequence, right_sequence, paths)
+
+ def _process_single_arm(self, left_sequence: HandSequence, right_sequence: HandSequence, paths) -> None:
+ """Process single-arm setup with one target hand."""
+ # Select target hand based on configuration
+ target_sequence = left_sequence if self.target_hand == "left" else right_sequence
+
+ # Process the selected hand sequence
+ target_actions = self._process_hand_sequence(target_sequence, self.T_cam2robot)
+
+ # Get indices where hand was detected for this sequence
+ union_indices = np.where(target_sequence.hand_detected)[0]
+
+ # Refine actions to handle missing detections
+ target_actions_refined = self._refine_actions(target_sequence, target_actions, union_indices, self.target_hand)
+
+ # Save results for the selected hand only
+ if self.target_hand == "left":
+ self._save_results(paths, union_indices=union_indices, left_actions=target_actions_refined)
+ else:
+ self._save_results(paths, union_indices=union_indices, right_actions=target_actions_refined)
+
+ def _process_bimanual(self, left_sequence: HandSequence, right_sequence: HandSequence, paths) -> None:
+ """Process bimanual setup with both hands."""
+ # Process both hand sequences
+ left_actions = self._process_hand_sequence(left_sequence, self.T_cam2robot)
+ right_actions = self._process_hand_sequence(right_sequence, self.T_cam2robot)
+
+ # Combine detection results using OR logic - frame is valid if either hand detected
+ union_indices = np.where(left_sequence.hand_detected | right_sequence.hand_detected)[0]
+
+ # Refine actions for both hands using the union indices
+ left_actions_refined = self._refine_actions(left_sequence, left_actions, union_indices, "left")
+ right_actions_refined = self._refine_actions(right_sequence, right_actions, union_indices, "right")
+
+ # Save results for both hands
+ self._save_results(paths, union_indices, left_actions_refined, right_actions_refined)
+
+
+ def _load_sequences(self, paths) -> Tuple[HandSequence, HandSequence]:
+ """
+ Load hand sequences from disk for both left and right hands.
+
+ HandSequence objects contain the processed keypoint data, detection flags,
+ and other metadata needed for action processing.
+
+ Args:
+ paths: Paths object containing file locations for hand data
+
+ Returns:
+ Tuple[HandSequence, HandSequence]: Left and right hand sequences
+ """
+ return (
+ HandSequence.load(paths.hand_data_left),
+ HandSequence.load(paths.hand_data_right)
+ )
+
+ def _process_hand_sequence(
+ self,
+ sequence: HandSequence,
+ T_cam2robot: np.ndarray,
+ ) -> EEActions:
+ """
+ Process a single hand sequence into end-effector actions.
+
+ This method performs the following processing pipeline for one hand:
+ 1. Transform keypoints from camera frame to robot frame
+ 2. Fit a hand model to the keypoint sequence
+ 3. Extract end-effector poses and gripper states
+
+ Args:
+ sequence (HandSequence): Hand keypoint sequence with detection flags
+ T_cam2robot (np.ndarray): 4x4 transformation matrix from camera to robot frame
+
+ Returns:
+ EEActions: Processed end-effector positions, orientations, and gripper widths
+ """
+ # Convert keypoints from camera frame to robot frame coordinates
+ kpts_3d_cf = sequence.kpts_3d # Camera frame keypoints
+ kpts_3d_rf = ActionProcessor._convert_pts_to_robot_frame(
+ kpts_3d_cf,
+ T_cam2robot
+ )
+
+ # Create and fit hand model to the keypoint sequence
+ hand_model = self._get_hand_model(kpts_3d_rf, sequence.hand_detected)
+
+ # Extract end-effector poses and gripper states from fitted model
+ kpts_3d, ee_pts, ee_oris = self._get_model_keypoints(hand_model)
+
+ # Compute gripper opening distances from fingertip positions
+ ee_widths = self._compute_gripper_distances(
+ kpts_3d,
+ sequence.hand_detected
+ )
+
+ return EEActions(
+ ee_pts=ee_pts,
+ ee_oris=ee_oris,
+ ee_widths=ee_widths,
+ )
+
+ def _get_hand_model(self, kpts_3d_rf: np.ndarray, hand_detected: np.ndarray) -> HandModel | PhysicallyConstrainedHandModel:
+ """
+ Create and fit a hand model to the keypoint sequence.
+
+ The hand model can be either unconstrained (simple fitting) or physically
+ constrained (enforces realistic hand poses and robot constraints).
+
+ Args:
+ kpts_3d_rf (np.ndarray): Hand keypoints in robot frame, shape (N, 21, 3)
+ hand_detected (np.ndarray): Boolean array indicating valid detections, shape (N,)
+
+ Returns:
+ HandModel | PhysicallyConstrainedHandModel: Fitted hand model with trajectory data
+ """
+ # Choose hand model type based on configuration
+ if self.constrained_hand:
+ hand_model = PhysicallyConstrainedHandModel(self.robot)
+ else:
+ hand_model = HandModel(self.robot)
+
+ # Add each frame to the model for trajectory fitting
+ for t_idx in range(len(kpts_3d_rf)):
+ hand_model.add_frame(
+ kpts_3d_rf[t_idx],
+ t_idx * self.dt, # Convert frame index to time
+ hand_detected[t_idx]
+ )
+ return hand_model
+
+ def _get_model_keypoints(self, model: HandModel | PhysicallyConstrainedHandModel) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Extract keypoints and end-effector data from fitted hand model.
+
+ Args:
+ model (HandModel | PhysicallyConstrainedHandModel): Fitted hand model
+
+ Returns:
+ Tuple containing:
+ - kpts_3d (np.ndarray): Model keypoint positions, shape (N, 21, 3)
+ - ee_pts (np.ndarray): End-effector positions, shape (N, 3)
+ - ee_oris (np.ndarray): End-effector orientations, shape (N, 3, 3)
+ """
+ kpts_3d = np.array(model.vertex_positions) # All hand keypoints
+ ee_pts = np.array(model.grasp_points) # End-effector positions (palm center)
+ ee_oris = np.array(model.grasp_oris) # End-effector orientations (rotation matrices)
+ return kpts_3d, ee_pts, ee_oris
+
+ def _compute_gripper_distances(
+ self,
+ kpts_3d_rf: np.ndarray,
+ hand_detected: np.ndarray
+ ) -> np.ndarray:
+ """
+ Compute gripper opening distances for all frames in the sequence.
+
+ The gripper distance is calculated as the Euclidean distance between
+ the thumb tip and index finger tip, providing a proxy for gripper state.
+
+ Args:
+ kpts_3d_rf (np.ndarray): Hand keypoints in robot frame, shape (N, 21, 3)
+ hand_detected (np.ndarray): Boolean flags for valid detections, shape (N,)
+
+ Returns:
+ np.ndarray: Gripper distances for each frame, shape (N,)
+ """
+ gripper_dists = np.zeros(len(kpts_3d_rf))
+
+ for idx in range(len(kpts_3d_rf)):
+ if hand_detected[idx]:
+ # Only compute distance for frames with valid hand detection
+ gripper_dists[idx] = ActionProcessor._compute_gripper_opening(
+ kpts_3d_rf[idx]
+ )
+ # Note: Invalid frames remain at 0.0, will be refined later
+ return gripper_dists
+
+ def _refine_actions(
+ self,
+ sequence: HandSequence,
+ actions: EEActions,
+ union_indices: np.ndarray,
+ hand_side: str
+ ) -> EEActions:
+ """
+ Refine actions to handle missing hand detections using last-valid-value interpolation.
+
+ When hand detection fails, this method fills in missing values by carrying forward
+ the last valid pose and gripper state. This creates smooth, executable trajectories
+ even when the vision system temporarily loses tracking.
+
+ Args:
+ sequence (HandSequence): Original hand sequence with detection flags
+ actions (EEActions): Raw actions from hand model
+ union_indices (np.ndarray): Frame indices to include in final trajectory
+ hand_side (str): "left" or "right" for neutral pose generation
+
+ Returns:
+ EEActions: Refined actions with interpolated values for missing detections
+ """
+ # Find frames where this hand was actually detected
+ hand_detected_indices = np.where(sequence.hand_detected)[0]
+
+ # If no valid detections, return neutral pose for entire sequence
+ if len(hand_detected_indices) == 0:
+ return self._get_neutral_actions(hand_side, len(union_indices))
+
+ # Apply carry-forward interpolation
+ return self._apply_carry_forward_interpolation(sequence, actions, union_indices, hand_detected_indices)
+
+ def _apply_carry_forward_interpolation(
+ self,
+ sequence: HandSequence,
+ actions: EEActions,
+ union_indices: np.ndarray,
+ hand_detected_indices: np.ndarray
+ ) -> EEActions:
+ """Apply last-valid-value interpolation to fill missing detections."""
+ # Initialize with first valid detection values
+ first_valid_idx = hand_detected_indices[0]
+ last_valid_pt = actions.ee_pts[first_valid_idx]
+ last_valid_ori = actions.ee_oris[first_valid_idx]
+ last_valid_width = actions.ee_widths[first_valid_idx]
+
+ # Process each frame in the union sequence
+ ee_pts_refined = []
+ ee_oris_refined = []
+ ee_widths_refined = []
+
+ for idx in union_indices:
+ if sequence.hand_detected[idx]:
+ # Update with new valid values when available
+ last_valid_pt = actions.ee_pts[idx]
+ last_valid_ori = actions.ee_oris[idx]
+ last_valid_width = actions.ee_widths[idx]
+
+ # Always append the last valid values (carry-forward for missing frames)
+ ee_pts_refined.append(last_valid_pt)
+ ee_oris_refined.append(last_valid_ori)
+ ee_widths_refined.append(last_valid_width)
+
+ return EEActions(
+ ee_pts=np.array(ee_pts_refined),
+ ee_oris=np.array(ee_oris_refined),
+ ee_widths=np.array(ee_widths_refined),
+ )
+
+ def _get_neutral_actions(self, hand_side: str, n_frames: int) -> EEActions:
+ """
+ Generate neutral pose actions when no hand detection is available.
+
+ Neutral poses place the robot arms in out-of-frame positions.
+
+ Args:
+ hand_side (str): "left" or "right" to determine which neutral pose to use
+ n_frames (int): Number of frames to generate
+
+ Returns:
+ EEActions: Neutral pose actions for the specified number of frames
+ """
+ # Define neutral pose configurations
+ neutral_configs = {
+ "single_arm": {
+ "right": {"pos": [0.2, -0.8, 0.3], "quat": [1, 0.0, 0.0, 0.0]},
+ "left": {"pos": [0.2, 0.8, 0.3], "quat": [1, 0.0, 0.0, 0.0]}
+ },
+ "shoulders": {
+ "right": {"pos": [0.4, -0.5, 0.3], "quat": [-0.7071, 0.0, 0.0, 0.7071]},
+ "left": {"pos": [0.4, 0.5, 0.3], "quat": [0.7071, 0.0, 0.0, 0.7071]}
+ }
+ }
+
+ # Get configuration for current setup and hand
+ config = neutral_configs[self.bimanual_setup][hand_side]
+
+ # Convert to numpy arrays and create rotation matrix
+ neutral_pos = np.array(config["pos"])
+ neutral_ori = Rotation.from_quat(config["quat"], scalar_first=False).as_matrix()
+ neutral_width = 0.085 # Standard gripper opening (8.5cm)
+
+ # Create arrays replicated for all frames
+ return EEActions(
+ ee_pts=np.repeat(neutral_pos.reshape(1, 3), n_frames, axis=0),
+ ee_oris=np.repeat(neutral_ori.reshape(1, 3, 3), n_frames, axis=0),
+ ee_widths=np.full(n_frames, neutral_width)
+ )
+
+ def _save_results(
+ self,
+ paths: Paths,
+ union_indices: np.ndarray,
+ left_actions: Optional[EEActions] = None,
+ right_actions: Optional[EEActions] = None,
+ ) -> None:
+ """
+ Save processed action results to disk in NPZ format.
+
+ The saved files contain all necessary data for robot execution:
+ - union_indices: Valid frame indices in the original sequence
+ - ee_pts: End-effector positions
+ - ee_oris: End-effector orientations (rotation matrices)
+ - ee_widths: Gripper opening widths
+
+ Args:
+ paths (Paths): File path configuration object
+ union_indices (np.ndarray): Valid frame indices
+ left_actions (Optional[EEActions]): Left hand actions to save
+ right_actions (Optional[EEActions]): Right hand actions to save
+ """
+ # Create output directory if it doesn't exist
+ os.makedirs(paths.action_processor, exist_ok=True)
+
+ # Save actions for each hand if provided
+ if left_actions is not None:
+ self._save_hand_actions(paths.actions_left, union_indices, left_actions)
+ if right_actions is not None:
+ self._save_hand_actions(paths.actions_right, union_indices, right_actions)
+
+ def _save_hand_actions(self, base_path: str, union_indices: np.ndarray, actions: EEActions) -> None:
+ """Save actions for a single hand to NPZ file."""
+ file_path = str(base_path).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ np.savez(
+ file_path,
+ union_indices=union_indices,
+ ee_pts=actions.ee_pts,
+ ee_oris=actions.ee_oris,
+ ee_widths=actions.ee_widths
+ )
+
+ @staticmethod
+ def _compute_gripper_opening(skeleton_pts: np.ndarray) -> float:
+ """
+ Compute gripper opening distance from hand keypoints for a single frame.
+
+ The gripper distance is calculated as the Euclidean distance between
+ the thumb tip and index finger tip.
+
+ Args:
+ skeleton_pts (np.ndarray): Hand keypoints for one frame, shape (21, 3)
+
+ Returns:
+ float: Distance between thumb tip and index finger tip in meters
+ """
+ # Extract finger tip positions from the hand skeleton
+ finger_dict = get_list_finger_pts_from_skeleton(skeleton_pts)
+
+ # Compute distance between thumb tip and index finger tip
+ return np.linalg.norm(finger_dict["thumb"][-1] - finger_dict["index"][-1])
+
+ @staticmethod
+ def _convert_pts_to_robot_frame(skeleton_poses_cf: np.ndarray, T_cam2robot: np.ndarray) -> np.ndarray:
+ """
+ Convert hand keypoints from camera frame to robot frame coordinates.
+
+ Args:
+ skeleton_poses_cf (np.ndarray): Hand poses in camera frame, shape (N, 21, 3)
+ T_cam2robot (np.ndarray): 4x4 transformation matrix from camera to robot frame
+
+ Returns:
+ np.ndarray: Hand poses in robot frame, shape (N, 21, 3)
+ """
+ # Convert to homogeneous coordinates by adding ones
+ pts_h = np.ones((skeleton_poses_cf.shape[0], skeleton_poses_cf.shape[1], 1))
+ skeleton_poses_cf_h = np.concatenate([skeleton_poses_cf, pts_h], axis=-1)
+
+ # Apply transformation matrix to convert coordinate frames
+ skeleton_poses_rf_h0 = np.einsum('ij,bpj->bpi', T_cam2robot, skeleton_poses_cf_h)
+
+ # Remove homogeneous coordinate and return 3D points
+ return skeleton_poses_rf_h0[..., :3]
\ No newline at end of file
diff --git a/phantom/phantom/processors/base_processor.py b/phantom/phantom/processors/base_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b6c85b0d1420415f833954bd2195341a153a27
--- /dev/null
+++ b/phantom/phantom/processors/base_processor.py
@@ -0,0 +1,209 @@
+import os
+import json
+import logging
+import numpy as np
+import shutil
+import errno
+from typing import Tuple
+from pathlib import Path
+from omegaconf import DictConfig
+
+from phantom.utils.data_utils import get_parent_folder_of_package
+from phantom.utils.image_utils import get_intrinsics_from_json, get_transformation_matrix_from_extrinsics
+from phantom.processors.paths import Paths, PathsConfig
+
+logger = logging.getLogger(__name__)
+
+class BaseProcessor:
+ def __init__(self, cfg: DictConfig):
+ # Store configuration for potential future use
+ self.cfg = cfg
+
+ # Apply configuration to instance attributes
+ self._apply_config(cfg)
+
+ # Validate configuration
+ self._validate_config(cfg)
+
+ # Set up paths and data folders
+ self._setup_paths_and_folders(cfg)
+
+ # Initialize camera parameters
+ self._init_camera_parameters()
+
+ def _apply_config(self, cfg: DictConfig) -> None:
+ """Apply configuration to instance attributes."""
+ # Basic attributes
+ self.input_resolution = cfg.input_resolution
+ self.output_resolution = cfg.output_resolution
+ self.project_folder = get_parent_folder_of_package("phantom")
+ self.debug = cfg.debug
+ self.n_processes = cfg.n_processes
+ self.verbose = cfg.verbose
+ self.skip_existing = cfg.skip_existing
+ self.robot = cfg.robot
+ self.gripper = cfg.gripper
+ self.square = cfg.square
+ self.epic = cfg.epic
+ self.bimanual_setup = cfg.bimanual_setup
+ self.target_hand = cfg.target_hand
+ self.constrained_hand = cfg.constrained_hand
+ self.depth_for_overlay = cfg.depth_for_overlay
+ self.render = cfg.render
+ self.debug_cameras = getattr(cfg, 'debug_cameras', [])
+
+ # Apply bimanual setup logic
+ if self.bimanual_setup != "single_arm":
+ self.target_hand = "both"
+
+ def _validate_config(self, cfg: DictConfig) -> None:
+ """Validate critical configuration parameters."""
+ if cfg.input_resolution <= 0 or cfg.output_resolution <= 0:
+ raise ValueError(f"Resolutions must be positive: input={cfg.input_resolution}, output={cfg.output_resolution}")
+
+ if not os.path.exists(cfg.data_root_dir):
+ raise FileNotFoundError(f"Data root directory not found: {cfg.data_root_dir}")
+
+ if not os.path.exists(cfg.camera_intrinsics):
+ raise FileNotFoundError(f"Camera intrinsics file not found: {cfg.camera_intrinsics}")
+
+ def _setup_paths_and_folders(self, cfg: DictConfig) -> None:
+ """Set up paths configuration and create necessary directories."""
+ # Set up paths configuration
+ self.paths_config = PathsConfig()
+ self.paths_config.config['data_root'] = cfg.data_root_dir
+ self.paths_config.config['processed_root'] = cfg.processed_data_root_dir
+
+ self.data_folder = os.path.join(cfg.data_root_dir, cfg.demo_name)
+ self.processed_data_folder = os.path.join(cfg.processed_data_root_dir, cfg.demo_name)
+
+ # Validate that data folder exists
+ if not os.path.exists(self.data_folder):
+ raise FileNotFoundError(f"Data folder not found: {self.data_folder}")
+
+ os.makedirs(self.processed_data_folder, exist_ok=True)
+
+ # Get all folders in data_folder
+ try:
+ all_data_folders = [d1 for d1 in os.listdir(self.data_folder) if os.path.isdir(os.path.join(self.data_folder, d1))]
+ self.all_data_folders = sorted(all_data_folders, key=lambda x: int(x))
+ self.all_data_folders_idx = {x: idx for idx, x in enumerate(self.all_data_folders)}
+ except OSError as e:
+ if e.errno == errno.EACCES:
+ raise PermissionError(f"Permission denied accessing data folder: {self.data_folder}")
+ elif e.errno == errno.ENOENT:
+ raise FileNotFoundError(f"Data folder not found: {self.data_folder}")
+ else:
+ raise RuntimeError(f"OS error accessing data folder {self.data_folder}: {e}")
+ except ValueError as e:
+ raise ValueError(f"Invalid folder name format in {self.data_folder}. Folders should be numbered: {e}")
+
+ def _init_camera_parameters(self) -> None:
+ """Initialize camera intrinsics and extrinsics."""
+ # Get camera intrinsics and extrinsics
+ self.intrinsics_dict, self.intrinsics_matrix = self.get_intrinsics(self.cfg.camera_intrinsics)
+
+ # Use camera_extrinsics from config if available, otherwise determine from bimanual_setup
+ if hasattr(self.cfg, 'camera_extrinsics') and self.cfg.camera_extrinsics:
+ camera_extrinsics_path = self.cfg.camera_extrinsics
+ else:
+ camera_extrinsics_path = self._get_camera_extrinsics_path()
+
+ self.T_cam2robot, self.extrinsics = self.get_extrinsics(camera_extrinsics_path)
+
+ def _get_camera_extrinsics_path(self) -> str:
+ """Get the appropriate camera extrinsics path based on bimanual setup."""
+ if self.bimanual_setup == "shoulders":
+ return "camera/camera_extrinsics_ego_bimanual_shoulders.json"
+ elif self.bimanual_setup == "single_arm":
+ return "camera/camera_extrinsics.json"
+ else:
+ raise ValueError(f"Invalid bimanual setup: {self.bimanual_setup}. Must be 'single_arm' or 'shoulders'.")
+
+ def get_paths(self, data_path: str) -> Paths:
+ """
+ Get all file paths for a demo.
+
+ Args:
+ data_path: Path to the demo data
+
+ Returns:
+ Paths object containing all file paths
+ """
+ paths = Paths(
+ data_path=Path(data_path),
+ robot_name=self.robot
+ )
+ paths.ensure_directories_exist()
+ return paths
+
+ def get_save_folder(self, data_sub_folder: str) -> str:
+ data_sub_folder_fullpath = os.path.join(self.data_folder, str(data_sub_folder))
+ save_folder = os.path.join(self.processed_data_folder, str(data_sub_folder))
+ # Check existing dirs using os.scandir
+ with os.scandir(self.processed_data_folder) as it:
+ existing_dirs = {entry.name for entry in it if entry.is_dir()}
+ if str(data_sub_folder) not in existing_dirs:
+ shutil.copytree(data_sub_folder_fullpath, save_folder)
+ return save_folder
+
+ def process_one_demo(self, data_sub_folder: str):
+ raise NotImplementedError
+
+ def get_intrinsics(self, intrinsics_path: str) -> Tuple[dict, np.ndarray]:
+ intrinsics_matrix, intrinsics_dict = get_intrinsics_from_json(intrinsics_path)
+ if self.square:
+ intrinsics_dict, intrinsics_matrix = self.update_intrinsics_for_square_image(self.input_resolution,
+ intrinsics_dict,
+ intrinsics_matrix)
+ return intrinsics_dict, intrinsics_matrix
+
+ def get_extrinsics(self, extrinsics_path: str) -> Tuple[np.ndarray, dict]:
+ """Load and process camera extrinsics from JSON file.
+
+ Args:
+ extrinsics_path: Path to the extrinsics JSON file
+
+ Returns:
+ Tuple of (transformation_matrix, extrinsics_dict)
+
+ Raises:
+ FileNotFoundError: If extrinsics file doesn't exist
+ json.JSONDecodeError: If extrinsics file is invalid JSON
+ ValueError: If extrinsics data is invalid
+ """
+ if not os.path.exists(extrinsics_path):
+ raise FileNotFoundError(f"Camera extrinsics file not found: {extrinsics_path}")
+
+ try:
+ with open(extrinsics_path, "r") as f:
+ camera_extrinsics = json.load(f)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON in extrinsics file {extrinsics_path}: {str(e)}")
+
+ try:
+ T_cam2robot = get_transformation_matrix_from_extrinsics(camera_extrinsics)
+ except Exception as e:
+ raise ValueError(f"Failed to process extrinsics data from {extrinsics_path}: {str(e)}")
+
+ return T_cam2robot, camera_extrinsics
+
+ @staticmethod
+ def update_intrinsics_for_square_image(img_h: int, intrinsics_dict: dict,
+ intrinsics_matrix: np.ndarray) -> Tuple[dict, np.ndarray]:
+ """
+ Adjusts camera intrinsic parameters for a square image by modifying the principal point offset.
+
+ Args:
+ img_h (int): Height of the image (assumed to be square).
+ intrinsics_dict (dict): Dictionary of intrinsic parameters.
+ intrinsics_matrix (np.ndarray): Intrinsic matrix.
+
+ Returns:
+ Tuple[dict, np.ndarray]: Updated intrinsic parameters and matrix.
+ """
+ img_w = img_h * 16 // 9
+ offset = (img_w - img_h) // 2
+ intrinsics_dict["cx"] -= offset
+ intrinsics_matrix[0, 2] -= offset
+ return intrinsics_dict, intrinsics_matrix
diff --git a/phantom/phantom/processors/bbox_processor.py b/phantom/phantom/processors/bbox_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83b54b913ccc762691aad77a3843ff135157d7c
--- /dev/null
+++ b/phantom/phantom/processors/bbox_processor.py
@@ -0,0 +1,851 @@
+"""
+Bounding Box Processor Module
+
+This module provides video processing capabilities for detecting and tracking hand bounding boxes
+in demonstration videos. It serves as the first stage in the hand processing pipeline, providing
+spatial localization data for downstream pose estimation and segmentation tasks.
+
+Key Features:
+- Multiple hand detection methods (DINO, EPIC-KITCHENS integration)
+- Bimanual hand tracking with left/right classification
+- Temporal consistency through outlier filtering and interpolation
+- Spatial constraint validation (edge detection, center positioning)
+- Visualization and annotation generation
+
+Processing Pipeline:
+1. Video loading and validation
+2. Frame-by-frame hand detection using configured detectors
+3. Bounding box classification (left/right) based on spatial positioning
+4. Temporal filtering to remove outliers and large jumps
+5. Gap interpolation for smooth trajectories
+6. Edge distance calculation for quality assessment
+7. Result visualization and storage
+
+The processor supports multiple detection backends:
+- DINO-based detection for general hand detection
+- EPIC-KITCHENS pre-computed detections
+- Configurable confidence thresholds and spatial constraints
+
+Output Data:
+- Hand detection flags per frame (boolean arrays)
+- Bounding box coordinates [x1, y1, x2, y2] per frame
+- Bounding box centers [x, y] per frame
+- Distance metrics to image edges
+- Annotated visualization videos
+"""
+
+import os
+import pickle
+import logging
+import numpy as np
+import mediapy as media
+import cv2
+import itertools
+import time
+import matplotlib.pyplot as plt
+from typing import List, Tuple, Optional, Any, Dict
+from typing_extensions import Literal
+import numpy.typing as npt
+from omegaconf import DictConfig
+
+from phantom.processors.base_processor import BaseProcessor
+from phantom.processors.paths import Paths
+from phantom.processors.phantom_data import hand_side_dict
+
+from phantom.utils.bbox_utils import get_bbox_center, get_bbox_center_min_dist_to_edge
+
+logger = logging.getLogger(__name__)
+
+# Type aliases for better readability
+DetectionResults = Dict[str, npt.NDArray]
+BBoxArray = npt.NDArray[np.float32] # [x1, y1, x2, y2]
+CenterArray = npt.NDArray[np.float32] # [x, y]
+DetectionFlagArray = npt.NDArray[np.bool_]
+HandSide = Literal["left", "right"]
+
+class BBoxProcessor(BaseProcessor):
+ # Detection configuration constants
+ HAND_SIDE_MARGIN = 50 # Pixel margin for hand side classification tolerance
+ OVERLAP_THRESHOLD = 0.3 # Threshold for considering bboxes as overlapping
+ MAX_INTERPOLATION_GAP = 10 # Maximum frames to interpolate over
+ MAX_SPATIAL_JUMP = 200.0 # Maximum allowed pixel jump between detections
+ MAX_JUMP_LOOKAHEAD = 10 # Maximum consecutive distant points to filter
+ DINO_CONFIDENCE_THRESH = 0.2 # Default confidence threshold
+
+ # Visualization constants
+ LEFT_HAND_COLOR = (0, 0, 255) # BGR format - Red for left hand
+ RIGHT_HAND_COLOR = (0, 255, 0) # BGR format - Green for right hand
+ BBOX_THICKNESS = 2 # Thickness of bounding box lines
+
+ """
+ Bounding box detection and tracking processor for hand localization in videos.
+
+ This processor serves as the foundation of the hand processing pipeline by detecting
+ and tracking hand bounding boxes across video frames. It handles both single-arm
+ and bimanual setups.
+
+ The processor employs multiple strategies for reliable detection:
+ - Primary detection using DINO or pre-computed EPIC data
+ - Spatial reasoning for left/right hand classification
+ - Temporal filtering to maintain trajectory consistency
+ - Gap interpolation for handling missing detections
+ - Quality assessment through edge distance metrics
+
+ Attributes:
+ H (int): Video frame height (set during processing)
+ W (int): Video frame width (set during processing)
+ center (int): Horizontal center of the frame for left/right classification
+ margin (int): Pixel margin for hand side classification tolerance
+ confidence_threshold (float): Minimum confidence for valid detections
+ dino_detector: DINO-based hand detector (if not using EPIC data)
+ filtered_hand_detection_data (dict): Processed EPIC detection data
+ sorted_keys (list): Sorted frame indices for EPIC data processing
+ """
+ def __init__(self, cfg: DictConfig) -> None:
+ """
+ Initialize the bounding box processor with configuration parameters.
+
+ Args:
+ cfg: Hydra configuration object containing processing configuration
+ including confidence thresholds, target hands, and dataset type
+ """
+ super().__init__(cfg)
+ # Image dimensions (set when processing video)
+ self.H: int = 0
+ self.W: int = 0
+
+ # Initialize detection backend based on dataset type
+ if not self.epic:
+ from phantom.detectors.detector_dino import DetectorDino
+ self.dino_detector: DetectorDino = DetectorDino("IDEA-Research/grounding-dino-base")
+ else:
+ self.dino_detector: Optional[DetectorDino] = None
+
+ # EPIC-specific attributes
+ self.filtered_hand_detection_data: Dict[str, List[Any]] = {}
+ self.sorted_keys: List[str] = []
+
+ # ============================================================================
+ # COMMON/SHARED METHODS (Used by both Phantom and EPIC modes)
+ # ============================================================================
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process a single demonstration video to extract hand bounding boxes.
+
+ Args:
+ data_sub_folder: Path to the demonstration data folder containing the video
+ and any pre-computed hand detection data.
+
+ The method performs the following steps:
+ 1. Loads and validates input video and detection data
+ 2. Processes each frame to detect and classify hand positions
+ 3. Applies post-processing filters for temporal consistency
+ 4. Generates quality metrics and visualizations
+ 5. Saves all results in standardized format
+
+ Raises:
+ FileNotFoundError: If required input files (video, detection data) are not found
+ ValueError: If video frames or hand detection data are invalid
+ """
+ # Setup and validation
+ save_folder = self.get_save_folder(data_sub_folder)
+
+ paths = self.get_paths(save_folder)
+
+ # Load and validate input data
+ imgs_rgb = self._load_video(paths)
+
+ # Process frames based on dataset type
+ if self.epic:
+ self._load_epic_hand_data(paths)
+ detection_results = self._process_epic_frames(imgs_rgb)
+ else:
+ detection_results = self._process_frames(imgs_rgb)
+
+ # Post-process results for temporal consistency
+ processed_results = self._post_process_detections(detection_results)
+
+ # Generate visualization for quality assessment
+ visualization_results = self._generate_visualization(imgs_rgb, processed_results)
+
+ # Save all results to disk
+ self._save_results(paths, processed_results, visualization_results)
+
+
+ def _load_video(self, paths: Paths) -> np.ndarray:
+ """
+ Load and validate video data from the specified path.
+
+ Args:
+ paths: Paths object containing video file locations
+
+ Returns:
+ RGB video frames as array
+
+ Raises:
+ FileNotFoundError: If video file doesn't exist
+ ValueError: If video is empty or corrupted
+ """
+ if not os.path.exists(paths.video_left):
+ raise FileNotFoundError(f"Video file not found: {paths.video_left}")
+
+ imgs_rgb = media.read_video(getattr(paths, f"video_left"))
+ if len(imgs_rgb) == 0:
+ raise ValueError("Empty video file")
+
+ # Store video dimensions for coordinate calculations
+ self.H, self.W, _ = imgs_rgb[0].shape
+ self.center: int = self.W // 2 # Center line for left/right classification
+ return imgs_rgb
+
+ # ============================================================================
+ # PHANTOM-SPECIFIC METHODS (DINO Detection)
+ # ============================================================================
+ def _process_frames(self, imgs_rgb: np.ndarray) -> Dict[str, np.ndarray]:
+ """
+ Process RGB frames using DINO detector for hand detection and classification.
+
+ This method handles the core detection pipeline for non-EPIC datasets,
+ using DINO for hand detection and implementing spatial reasoning for
+ left/right classification.
+
+ Args:
+ imgs_rgb: Array of RGB images with shape (num_frames, height, width, 3)
+
+ Returns:
+ Dictionary containing:
+ - left/right_hand_detected: Boolean arrays indicating hand detection per frame
+ - left/right_bboxes: Bounding box coordinates [x1,y1,x2,y2] per frame
+ - left/right_bboxes_ctr: Bounding box centers [x,y] per frame
+ """
+ num_frames = len(imgs_rgb)
+
+ detection_arrays = self._initialize_detection_arrays(num_frames)
+
+ for idx in range(num_frames):
+ try:
+ # Run DINO detection on current frame
+ bboxes, scores = self.dino_detector.get_bboxes(imgs_rgb[idx], "a hand", threshold=self.DINO_CONFIDENCE_THRESH, visualize=False)
+ if len(bboxes) == 0:
+ continue
+
+ bboxes = np.array(bboxes)
+ scores = np.array(scores)
+
+ # Process detections for current frame
+ self._process_frame_detections(idx, bboxes, scores, detection_arrays)
+ except Exception as e:
+ logger.warning(f"Frame {idx} processing failed: {str(e)}")
+ continue
+
+ return {
+ 'left_hand_detected': detection_arrays['left_hand_detected'],
+ 'right_hand_detected': detection_arrays['right_hand_detected'],
+ 'left_bboxes': detection_arrays['left_bboxes'],
+ 'right_bboxes': detection_arrays['right_bboxes'],
+ 'left_bboxes_ctr': detection_arrays['left_bboxes_ctr'],
+ 'right_bboxes_ctr': detection_arrays['right_bboxes_ctr'],
+ }
+
+ def _initialize_detection_arrays(self, num_frames: int) -> Dict[str, npt.NDArray]:
+ """
+ Initialize arrays for storing detection results.
+
+ Args:
+ num_frames: Number of frames in the video
+
+ Returns:
+ Dictionary containing pre-allocated arrays for left/right hand detections,
+ bounding boxes, centers, and detection flags
+ """
+ return {
+ 'left_bboxes': np.zeros((num_frames, 4)),
+ 'right_bboxes': np.zeros((num_frames, 4)),
+ 'left_bboxes_ctr': np.zeros((num_frames, 2)),
+ 'right_bboxes_ctr': np.zeros((num_frames, 2)),
+ 'left_hand_detected': np.zeros(num_frames, dtype=bool),
+ 'right_hand_detected': np.zeros(num_frames, dtype=bool)
+ }
+
+ def _process_frame_detections(self, idx: int, bboxes: npt.NDArray, scores: npt.NDArray,
+ detection_arrays: Dict[str, npt.NDArray]) -> None:
+ """
+ Process detections for a single frame.
+
+ Args:
+ idx: Frame index
+ bboxes: Array of detected bounding boxes
+ scores: Array of detection confidence scores
+ detection_arrays: Dictionary to store detection results
+ """
+ if len(bboxes) == 0:
+ return
+
+ # Always select the bounding box with the highest score
+ best_idx = np.argmax(scores)
+ best_bbox = bboxes[best_idx]
+ best_bbox_ctr = get_bbox_center(best_bbox)
+
+ # Assign hand type directly based on self.target_hand
+ if self.target_hand == "left":
+ detection_arrays['left_bboxes'][idx] = best_bbox
+ detection_arrays['left_bboxes_ctr'][idx] = best_bbox_ctr
+ detection_arrays['left_hand_detected'][idx] = True
+ elif self.target_hand == "right":
+ detection_arrays['right_bboxes'][idx] = best_bbox
+ detection_arrays['right_bboxes_ctr'][idx] = best_bbox_ctr
+ detection_arrays['right_hand_detected'][idx] = True
+
+
+ # ============================================================================
+ # EPIC-SPECIFIC METHODS (EPIC Dataset Processing)
+ # ============================================================================
+
+ def _validate_epic_data_structure(self, epic_data: List[Any]) -> bool:
+ """Validate EPIC data structure before processing."""
+ if not epic_data:
+ return False
+
+ # Check if first item has required attributes
+ try:
+ first_item = epic_data[0]
+ if not hasattr(first_item, 'side') or not hasattr(first_item, 'bbox'):
+ logging.warning("EPIC data missing required attributes: 'side' or 'bbox'")
+ return False
+
+ # Check if bbox has required attributes
+ bbox = first_item.bbox
+ required_attrs = ['left', 'right', 'top', 'bottom']
+ if not all(hasattr(bbox, attr) for attr in required_attrs):
+ logging.warning("EPIC bbox missing required attributes: left, right, top, bottom")
+ return False
+
+ return True
+ except Exception as e:
+ logging.warning(f"Error validating EPIC data structure: {str(e)}")
+ return False
+
+ def _load_epic_hand_data(self, paths: Paths) -> Dict[str, Any]:
+ """
+ Load and validate pre-computed hand detection data from EPIC-KITCHENS dataset.
+
+ EPIC-KITCHENS provides pre-computed hand detection annotations that we can
+ use directly instead of running our own detection. This method filters and
+ sorts the data for efficient frame-by-frame processing.
+
+ Args:
+ paths: Paths object containing detection data file location
+
+ Returns:
+ Dictionary of filtered and sorted hand detection data
+
+ Raises:
+ FileNotFoundError: If detection data file doesn't exist
+ """
+ if not os.path.exists(paths.hand_detection_data):
+ raise FileNotFoundError(f"Hand detection data not found: {paths.hand_detection_data}")
+
+ with open(paths.hand_detection_data, 'rb') as f:
+ hand_detection_data = dict(pickle.load(f))
+
+ # Filter out detection objects without valid side information
+ filtered_data = {
+ key: [obj for obj in obj_list if hasattr(obj, 'side')]
+ for key, obj_list in hand_detection_data.items()
+ }
+
+ # Sort by frame index for sequential processing
+ self.filtered_hand_detection_data = dict(sorted(filtered_data.items(), key=lambda x: int(x[0])))
+ self.sorted_keys = sorted(self.filtered_hand_detection_data.keys(), key=lambda k: int(k))
+
+ return self.filtered_hand_detection_data
+
+ def _process_epic_frames(self, imgs_rgb: npt.NDArray[np.uint8]) -> DetectionResults:
+ """
+ Process frames using pre-computed EPIC-KITCHENS hand detection data.
+
+ This method processes EPIC-KITCHENS dataset videos using their provided
+ hand detection annotations, converting them to our standard format while
+ applying spatial validation constraints.
+
+ Args:
+ imgs_rgb: Array of RGB images for dimension reference
+
+ Returns:
+ Dictionary containing detection results in the same format as _process_frames
+ """
+ num_frames = len(imgs_rgb)
+
+ detection_arrays = self._initialize_detection_arrays(num_frames)
+
+ # Process each frame using EPIC detection data
+ for idx in range(num_frames):
+ try:
+ epic_data = self.filtered_hand_detection_data[self.sorted_keys[idx]]
+
+ if len(epic_data) == 0:
+ continue
+
+ # Process frame detections
+ self._process_epic_frame_detections(idx, epic_data, detection_arrays)
+ except KeyError:
+ logger.warning(f"Missing EPIC data for frame {idx}")
+ continue
+ except Exception as e:
+ logger.warning(f"EPIC frame {idx} processing failed: {str(e)}")
+ continue
+
+ return {
+ 'left_hand_detected': detection_arrays['left_hand_detected'],
+ 'right_hand_detected': detection_arrays['right_hand_detected'],
+ 'left_bboxes': detection_arrays['left_bboxes'],
+ 'right_bboxes': detection_arrays['right_bboxes'],
+ 'left_bboxes_ctr': detection_arrays['left_bboxes_ctr'],
+ 'right_bboxes_ctr': detection_arrays['right_bboxes_ctr']
+ }
+
+ def _process_epic_frame_detections(self, idx: int, epic_data: List[Any],
+ detection_arrays: Dict[str, npt.NDArray]) -> None:
+ """Process EPIC detections for a single frame."""
+ # Process left and right hands separately
+ left_detected, left_bbox, left_bbox_ctr = self._process_epic_hand_detection(epic_data, "left")
+ right_detected, right_bbox, right_bbox_ctr = self._process_epic_hand_detection(epic_data, "right")
+
+ # Store results in pre-allocated arrays
+ detection_arrays['left_hand_detected'][idx] = left_detected
+ detection_arrays['right_hand_detected'][idx] = right_detected
+ if left_detected:
+ detection_arrays['left_bboxes'][idx] = left_bbox
+ detection_arrays['left_bboxes_ctr'][idx] = left_bbox_ctr
+ if right_detected:
+ detection_arrays['right_bboxes'][idx] = right_bbox
+ detection_arrays['right_bboxes_ctr'][idx] = right_bbox_ctr
+
+ # Quality check: If hands appear crossed (left hand on right side),
+ # mark both as invalid to avoid confusion
+ if left_detected and right_detected:
+ self._validate_hand_positions(idx, left_bbox_ctr, right_bbox_ctr, detection_arrays)
+
+ def _validate_hand_positions(self, idx: int, left_bbox_ctr: npt.NDArray, right_bbox_ctr: npt.NDArray,
+ detection_arrays: Dict[str, npt.NDArray]) -> None:
+ """Validate that hands are on correct sides of the image."""
+ if left_bbox_ctr[0] > right_bbox_ctr[0]:
+ # Left hand appears to be on the right side - mark both as invalid
+ detection_arrays['left_hand_detected'][idx] = False
+ detection_arrays['right_hand_detected'][idx] = False
+
+ def _process_epic_hand_detection(self,
+ epic_data: List[Any],
+ hand_side: HandSide) -> Tuple[bool, BBoxArray, CenterArray]:
+ """
+ Process EPIC hand detection data for a single frame and hand side.
+
+ This method extracts and validates hand detection data from EPIC annotations,
+ converting normalized coordinates to pixel coordinates and applying spatial
+ validation constraints.
+
+ Args:
+ epic_data: List of detection objects for the current frame
+ hand_side: Either "left" or "right" specifying which hand to process
+
+ Returns:
+ Tuple of (is_detected: bool, bbox: ndarray, bbox_center: ndarray)
+ """
+ if hand_side not in hand_side_dict:
+ raise ValueError(f"Invalid hand side: {hand_side}")
+
+ # Default empty result for failed detections
+ empty_result = (False, np.array([0, 0, 0, 0]), np.array([0, 0]))
+
+ try:
+ # Filter and validate detection data
+ hand_data = self._filter_epic_hand_data(epic_data, hand_side)
+ if not hand_data:
+ return empty_result
+
+ # Validate data structure
+ if not self._validate_epic_data_structure(hand_data):
+ return empty_result
+
+ # Extract and process bounding box
+ bbox, bbox_center = self._extract_epic_bbox(hand_data[0])
+
+ # Validate bounding box coordinates
+ if not self._validate_bbox_coordinates(hand_data[0].bbox, hand_side):
+ return empty_result
+
+ # Apply spatial validation
+ is_valid = self._validate_spatial_position(bbox_center, hand_side)
+ return (is_valid, bbox, bbox_center) if is_valid else empty_result
+
+ except Exception as e:
+ logging.warning(f"Unexpected error processing {hand_side} hand detection: {str(e)}")
+ return empty_result
+
+ def _filter_epic_hand_data(self, epic_data: List[Any], hand_side: HandSide) -> List[Any]:
+ """Filter EPIC detection data for the specified hand side."""
+ return [data for data in epic_data if data.side.value == hand_side_dict[hand_side]]
+
+ def _extract_epic_bbox(self, hand_data: Any) -> Tuple[BBoxArray, CenterArray]:
+ """Extract bounding box and center from EPIC hand detection data."""
+ bbox_cls = hand_data.bbox
+
+ # Convert normalized coordinates to pixel coordinates
+ bbox = np.array([
+ bbox_cls.left * self.W,
+ bbox_cls.top * self.H,
+ bbox_cls.right * self.W,
+ bbox_cls.bottom * self.H
+ ])
+
+ # Calculate center point for spatial validation
+ bbox_center = np.array([
+ (bbox[0] + bbox[2]) / 2,
+ (bbox[1] + bbox[3]) / 2
+ ]).astype(np.int32)
+
+ return bbox, bbox_center
+
+ def _validate_spatial_position(self, bbox_center: CenterArray, hand_side: HandSide) -> bool:
+ """Validate that hand center is on correct side of image."""
+ if hand_side == "left":
+ return bbox_center[0] <= (self.center + self.HAND_SIDE_MARGIN)
+ else: # right
+ return bbox_center[0] >= (self.center - self.HAND_SIDE_MARGIN)
+
+ def _validate_bbox_coordinates(self, bbox_cls: Any, hand_side: HandSide) -> bool:
+ """Validate bounding box coordinates are within valid range [0,1]."""
+ if not (0 <= bbox_cls.left <= 1 and 0 <= bbox_cls.right <= 1 and
+ 0 <= bbox_cls.top <= 1 and 0 <= bbox_cls.bottom <= 1):
+ logging.warning(f"Invalid bbox coordinates detected for {hand_side} hand: "
+ f"left={bbox_cls.left:.3f}, right={bbox_cls.right:.3f}, "
+ f"top={bbox_cls.top:.3f}, bottom={bbox_cls.bottom:.3f}")
+ return False
+ return True
+
+
+ # ============================================================================
+ # UTILITY/HELPER METHODS (General utilities and post-processing)
+ # ============================================================================
+
+
+ def _post_process_detections(self, detection_results: DetectionResults) -> DetectionResults:
+ """
+ Apply post-processing to improve detection temporal consistency.
+
+ This method applies several filters and enhancements to the raw detection
+ results to improve their quality and temporal coherence:
+ 1. Filter out large spatial jumps that indicate tracking errors
+ 2. Interpolate short gaps in detection sequences
+ 3. Calculate quality metrics (distance to image edges)
+
+ Args:
+ detection_results: Raw detection results from frame processing
+
+ Returns:
+ Enhanced detection results with improved temporal consistency
+ """
+ # Filter out large jumps for both hands
+ left_results = self._filter_large_jumps(
+ detection_results['left_hand_detected'],
+ detection_results['left_bboxes'],
+ detection_results['left_bboxes_ctr'],
+ max_jump=self.MAX_SPATIAL_JUMP,
+ lookahead=self.MAX_JUMP_LOOKAHEAD
+ )
+ right_results = self._filter_large_jumps(
+ detection_results['right_hand_detected'],
+ detection_results['right_bboxes'],
+ detection_results['right_bboxes_ctr'],
+ max_jump=self.MAX_SPATIAL_JUMP,
+ lookahead=self.MAX_JUMP_LOOKAHEAD
+ )
+
+ # Interpolate missing detections for smooth trajectories
+ left_results = self._interpolate_detections(*left_results, max_gap=self.MAX_INTERPOLATION_GAP)
+ right_results = self._interpolate_detections(*right_results, max_gap=self.MAX_INTERPOLATION_GAP)
+
+ # Calculate quality metrics: minimum distance from bbox center to image edges
+ left_bbox_min_dist = get_bbox_center_min_dist_to_edge(left_results[1], self.W, self.H)
+ right_bbox_min_dist = get_bbox_center_min_dist_to_edge(right_results[1], self.W, self.H)
+
+ return {
+ 'left_hand_detected': left_results[0],
+ 'right_hand_detected': right_results[0],
+ 'left_bboxes': left_results[1],
+ 'right_bboxes': right_results[1],
+ 'left_bboxes_ctr': left_results[2],
+ 'right_bboxes_ctr': right_results[2],
+ 'left_bbox_min_dist_to_edge': left_bbox_min_dist,
+ 'right_bbox_min_dist_to_edge': right_bbox_min_dist
+ }
+
+ def _generate_visualization(self, imgs_rgb: np.ndarray, results: Dict[str, np.ndarray]) -> List[np.ndarray]:
+ """
+ Generate visualization of detection results for quality assessment.
+
+ Creates annotated frames showing detected bounding boxes for visual
+ inspection of detection quality and temporal consistency.
+
+ Args:
+ imgs_rgb: Original RGB video frames
+ results: Processed detection results
+
+ Returns:
+ List of annotated images with bounding boxes drawn
+ """
+ list_img_annot = []
+ for idx in range(len(imgs_rgb)):
+ left_bbox = None
+ right_bbox = None
+
+ # Prepare bounding boxes for visualization
+ if results['left_hand_detected'][idx] or results['right_hand_detected'][idx]:
+ left_bbox = results['left_bboxes'][idx] if results['left_hand_detected'][idx] else None
+ right_bbox = results['right_bboxes'][idx] if results['right_hand_detected'][idx] else None
+
+ # Generate annotated image
+ img_annot = self.visualize_detections(imgs_rgb[idx], left_bbox, right_bbox, show_image=False)
+ list_img_annot.append(img_annot)
+ return list_img_annot
+
+ def _save_results(self, paths: Paths, results: DetectionResults, visualization_results: List[npt.NDArray[np.uint8]]) -> None:
+ """
+ Save all processed results to disk in standardized format.
+
+ Args:
+ paths: Paths object containing output file locations
+ results: Processed detection results
+ visualization_results: Generated visualization frames
+ """
+ # Create output directory if it doesn't exist
+ if not os.path.exists(paths.bbox_processor):
+ os.makedirs(paths.bbox_processor)
+
+ # Save detection data in compressed NumPy format
+ np.savez(paths.bbox_data, **results)
+
+ # Save visualization video with lossless compression
+ media.write_video(paths.video_bboxes, visualization_results, fps=15, codec="ffv1")
+
+ def _interpolate_detections(self, detected: DetectionFlagArray,
+ bboxes: BBoxArray,
+ centers: CenterArray,
+ max_gap: int = 10) -> Tuple[DetectionFlagArray, BBoxArray, CenterArray]:
+ """
+ Interpolate bounding boxes and detection status for short gaps in tracking.
+
+ This method fills in missing detections using linear interpolation when the
+ gap is small enough to reasonably assume continuous hand motion. This helps
+ create smoother trajectories for downstream processing.
+
+ Args:
+ detected: Boolean array of detection status per frame
+ bboxes: Array of bounding boxes [N, 4] format [x1, y1, x2, y2]
+ centers: Array of bbox centers [N, 2] format [x, y]
+ max_gap: Maximum gap size (in frames) to interpolate over
+
+ Returns:
+ Tuple of (interpolated detection status, interpolated bboxes, interpolated centers)
+ """
+ detected = detected.copy()
+ bboxes = bboxes.copy()
+ centers = centers.copy()
+
+ # Handle single-frame gaps first (most common case)
+ for i in range(1, len(detected) - 1):
+ if not detected[i] and detected[i-1] and detected[i+1]:
+ # Get valid bboxes/centers before and after gap
+ start_bbox = bboxes[i-1]
+ end_bbox = bboxes[i+1]
+ start_center = centers[i-1]
+ end_center = centers[i+1]
+
+ # Linear interpolation with t = 0.5 for single frame
+ interpolated_bbox = 0.5 * (start_bbox + end_bbox)
+ interpolated_center = 0.5 * (start_center + end_center)
+
+ # Validate interpolated values are reasonable
+ if self._is_valid_bbox(interpolated_bbox) and self._is_valid_center(interpolated_center):
+ bboxes[i] = interpolated_bbox
+ centers[i] = interpolated_center
+ detected[i] = True
+
+ # Handle multi-frame gaps
+ non_detect_start = None
+ for i in range(1, len(detected) - 1):
+ # Start of non-detection sequence
+ if detected[i-1] and not detected[i]:
+ non_detect_start = i
+ # End of non-detection sequence
+ elif non_detect_start is not None and not detected[i] and detected[i+1]:
+ non_detect_end = i
+ gap_size = non_detect_end - non_detect_start + 1
+
+ # Only interpolate if gap is small enough and has valid detections on both sides
+ if gap_size <= max_gap:
+ # Get valid bboxes/centers before and after gap
+ start_bbox = bboxes[non_detect_start - 1]
+ end_bbox = bboxes[non_detect_end + 1]
+ start_center = centers[non_detect_start - 1]
+ end_center = centers[non_detect_end + 1]
+
+ # Generate interpolation steps
+ steps = gap_size + 1
+ for j in range(gap_size):
+ t = (j + 1) / steps # Interpolation factor
+
+ # Linear interpolation of bbox coordinates
+ bboxes[non_detect_start + j] = (1 - t) * start_bbox + t * end_bbox
+
+ # Linear interpolation of center coordinates
+ centers[non_detect_start + j] = (1 - t) * start_center + t * end_center
+
+ # Mark as detected
+ detected[non_detect_start + j] = True
+
+ non_detect_start = None
+
+ return detected, bboxes, centers
+
+ def _is_valid_bbox(self, bbox: BBoxArray) -> bool:
+ """Validate that bbox coordinates are reasonable."""
+ if bbox is None or len(bbox) != 4:
+ return False
+ # Check for reasonable bounds (not negative, not too large)
+ return (bbox >= 0).all() and (bbox[:2] < bbox[2:]).all() and bbox.max() < max(self.W, self.H) * 2
+
+ def _is_valid_center(self, center: CenterArray) -> bool:
+ """Validate that center coordinates are reasonable."""
+ if center is None or len(center) != 2:
+ return False
+ # Check for reasonable bounds
+ return (center >= 0).all() and center[0] < self.W * 2 and center[1] < self.H * 2
+
+ def visualize_detections(self, img: npt.NDArray[np.uint8],
+ left_bbox: Optional[npt.NDArray[np.float32]] = None,
+ right_bbox: Optional[npt.NDArray[np.float32]] = None,
+ show_image: bool = True) -> npt.NDArray[np.uint8]:
+ """
+ Visualize hand detections by drawing bounding boxes on the image.
+
+ This method creates annotated images showing detected hand locations with
+ color-coded bounding boxes (red for left hand, green for right hand).
+
+ Args:
+ img: Input RGB image to annotate
+ left_bbox: Left hand bounding box [x1, y1, x2, y2] or None if not detected
+ right_bbox: Right hand bounding box [x1, y1, x2, y2] or None if not detected
+ show_image: Whether to display the image using cv2.imshow
+
+ Returns:
+ The annotated image
+ """
+ # Work directly with the input image (assumed to be in BGR format)
+ img_bgr = img
+
+ # Draw left hand bounding box in red
+ if left_bbox is not None and not np.array_equal(left_bbox, np.array([0, 0, 0, 0])):
+ cv2.rectangle(
+ img_bgr,
+ (int(left_bbox[0]), int(left_bbox[1])),
+ (int(left_bbox[2]), int(left_bbox[3])),
+ self.LEFT_HAND_COLOR,
+ self.BBOX_THICKNESS
+ )
+
+ # Draw right hand bounding box in green
+ if right_bbox is not None and not np.array_equal(right_bbox, np.array([0, 0, 0, 0])):
+ cv2.rectangle(
+ img_bgr,
+ (int(right_bbox[0]), int(right_bbox[1])),
+ (int(right_bbox[2]), int(right_bbox[3])),
+ self.RIGHT_HAND_COLOR,
+ self.BBOX_THICKNESS
+ )
+
+ # Optionally display the image for debugging
+ if show_image:
+ cv2.imshow("Hand Detections", img_bgr)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+ return img_bgr
+
+ @staticmethod
+ def _filter_large_jumps(detected: DetectionFlagArray,
+ bboxes: BBoxArray,
+ centers: CenterArray,
+ max_jump: float = 200.0,
+ lookahead: int = 10) -> Tuple[DetectionFlagArray, BBoxArray, CenterArray]:
+ """
+ Filter out small groups of detections that are spatially inconsistent with the trajectory.
+
+ This method identifies and removes isolated detections that are far from the
+ expected trajectory, which usually indicate false positives or tracking errors.
+ It helps maintain temporal consistency in hand tracking.
+
+ Args:
+ detected: Boolean array of detection status per frame
+ bboxes: Array of bounding boxes [N, 4] format [x1, y1, x2, y2]
+ centers: Array of bbox centers [N, 2] format [x, y]
+ max_jump: Maximum allowed distance (in pixels) between consecutive detections
+ lookahead: Maximum number of consecutive distant points to filter as a group
+
+ Returns:
+ Tuple of (filtered detection status, filtered bboxes, filtered centers)
+ """
+ detected = detected.copy()
+ bboxes = bboxes.copy()
+ centers = centers.copy()
+
+ # Templates for clearing invalid detections
+ empty_bbox = np.array([0, 0, 0, 0])
+ empty_center = np.array([0, 0])
+
+ i = 0
+ while i < len(detected):
+ # Find next detected point to compare against
+ next_valid = i + 1
+
+ if next_valid >= len(detected):
+ break
+
+ # Calculate spatial distance to next detection
+ dist = np.linalg.norm(centers[next_valid] - centers[i])
+
+ if dist > max_jump:
+ # Large jump detected - check if it's part of a small group of outliers
+ distant_points = []
+ ref_center = centers[i] # Use current point as reference
+
+ # Look ahead to find consecutive distant points
+ for j in range(next_valid, len(detected)):
+ curr_dist = np.linalg.norm(centers[j] - ref_center)
+ if curr_dist > max_jump:
+ distant_points.append(j)
+ else:
+ break
+
+ # If we found a small group of distant points, filter them out
+ if len(distant_points) > 0 and len(distant_points) <= lookahead:
+ for idx in distant_points:
+ detected[idx] = False
+ bboxes[idx] = empty_bbox
+ centers[idx] = empty_center
+ logging.warning(f"Filtered out frame {idx} as part of small distant group")
+
+ i = next_valid
+
+ return detected, bboxes, centers
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/phantom/processors/hand_processor.py b/phantom/phantom/processors/hand_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d18a549233d6682a698c4968e43911a1299523a
--- /dev/null
+++ b/phantom/phantom/processors/hand_processor.py
@@ -0,0 +1,675 @@
+"""
+Hand Processor Module
+
+This module converts detected hand bounding boxes into detailed 3D hand poses using
+state-of-the-art pose estimation models, with optional depth-based refinement for improved accuracy.
+
+Processing Pipeline:
+1. Load video frames and bounding box data from previous stage
+2. Apply HaMeR pose estimation within detected bounding boxes
+3. Filter poses based on edge proximity and quality metrics
+4. Optionally refine 3D poses using depth data and segmentation
+5. Generate hand mesh models and extract keypoint trajectories
+6. Save processed hand sequences for downstream tasks
+
+The module supports multiple processing modes:
+- Hand2DProcessor: 2D pose estimation only (faster, camera-based)
+- Hand3DProcessor: Full 3D processing with depth alignment (more accurate, if depth is available)
+
+Output Data:
+- HandSequence objects containing pose trajectories
+- 2D keypoint positions in image coordinates
+- 3D keypoint positions in camera coordinates
+- Hand detection flags per frame
+- Annotated visualization videos
+"""
+
+import glob
+import os
+import logging
+from tqdm import tqdm
+import numpy as np
+import mediapy as media
+import open3d as o3d # type: ignore
+from typing import Tuple, Optional, Dict, Any
+import trimesh
+from collections import defaultdict
+import argparse
+
+from phantom.utils.pcd_utils import get_visible_points, get_pcd_from_points, icp_registration, get_point_cloud_of_segmask, get_3D_points_from_pixels, remove_outliers, get_bbox_of_3d_points, trim_pcd_to_bbox, visualize_pcds
+from phantom.utils.transform_utils import transform_pts
+from phantom.processors.base_processor import BaseProcessor
+from phantom.detectors.detector_hamer import DetectorHamer
+from phantom.processors.phantom_data import HandSequence, HandFrame, hand_side_dict
+from phantom.processors.paths import Paths
+from phantom.processors.segmentation_processor import HandSegmentationProcessor
+
+logger = logging.getLogger(__name__)
+
+class HandBaseProcessor(BaseProcessor):
+ """
+ Base class for hand pose processing using HaMeR detection and optional depth refinement.
+
+ The processor operates on the output of BBoxProcessor, using detected hand bounding boxes
+ to guide pose estimation. It supports both 2D and 3D processing modes, with the 3D mode
+ providing enhanced accuracy through depth sensor integration.
+
+ Processing Workflow:
+ 1. Load video frames and bounding box detection results
+ 2. For each frame with detected hands:
+ - Apply HaMeR pose estimation within bounding box
+ - Validate pose quality (edge proximity, confidence)
+ - Optionally generate hand segmentation masks for depth refinement
+ - Optionally apply depth-based pose refinement
+ 3. Generate temporal hand sequences with smooth trajectories
+ 4. Save processed results and visualization videos
+
+ Attributes:
+ process_hand_masks (bool): Whether to generate hand segmentation masks
+ apply_depth_alignment (bool): Whether to use depth-based pose refinement
+ detector_hamer (DetectorHamer): HaMeR pose estimation model
+ hand_mask_processor: Segmentation processor for hand mask generation
+ H (int): Video frame height
+ W (int): Video frame width
+ imgs_depth (np.ndarray): Depth images for 3D refinement
+ left_masks (np.ndarray): Left hand segmentation masks
+ right_masks (np.ndarray): Right hand segmentation masks
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize the hand processor with configuration parameters.
+
+ Args:
+ args: Command line arguments containing processing configuration
+ including depth processing flags and model parameters
+ """
+ super().__init__(args)
+ self.process_hand_masks: bool = False
+ self._initialize_detectors()
+ self.hand_mask_processor: Optional[HandSegmentationProcessor] = None
+ self.apply_depth_alignment: bool = False
+
+ def _initialize_detectors(self) -> None:
+ """
+ Initialize all required detection models.
+
+ Sets up the HaMeR detector for hand pose estimation.
+ """
+ self.detector_hamer = DetectorHamer()
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process a single demonstration video to extract hand poses and segmentation.
+
+ Args:
+ data_sub_folder: Path to the demonstration data folder containing
+ video files, bounding box data, and optional depth data
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+
+ paths = self.get_paths(save_folder)
+
+ # Load RGB video frames
+ imgs_rgb = media.read_video(getattr(paths, f"video_left"))
+ self.H, self.W, _ = imgs_rgb[0].shape
+
+ # Load depth data if available (for 3D processing)
+ if os.path.exists(paths.depth):
+ self.imgs_depth = np.load(paths.depth)
+ else:
+ self.imgs_depth = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
+
+ # Load hand segmentation masks if available
+ if os.path.exists(paths.masks_hand_left) and os.path.exists(paths.masks_hand_right):
+ self.left_masks = np.load(paths.masks_hand_left)
+ self.right_masks = np.load(paths.masks_hand_right)
+ else:
+ self.left_masks = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
+ self.right_masks = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
+
+ # Load bounding box detection results from previous stage
+ bbox_data = np.load(paths.bbox_data)
+ left_hand_detected = bbox_data["left_hand_detected"]
+ right_hand_detected = bbox_data["right_hand_detected"]
+ left_bboxes = bbox_data["left_bboxes"]
+ right_bboxes = bbox_data["right_bboxes"]
+
+ # Validate data consistency
+ assert len(left_hand_detected) == len(right_hand_detected)
+ assert len(left_hand_detected) == len(imgs_rgb)
+
+ # Process left and right hand sequences
+ left_sequence = self._process_all_frames(imgs_rgb, left_bboxes, left_hand_detected, "left")
+ right_sequence = self._process_all_frames(imgs_rgb, right_bboxes, right_hand_detected, "right")
+
+ # Generate hand segmentation masks if enabled
+ if self.process_hand_masks:
+ self._get_hand_masks(data_sub_folder, left_sequence, right_sequence)
+ self.left_masks = np.load(paths.masks_hand_left)
+ self.right_masks = np.load(paths.masks_hand_right)
+
+ # Apply depth-based pose refinement if enabled
+ if self.apply_depth_alignment:
+ left_sequence = self._process_all_frames_depth_alignment(imgs_rgb, left_hand_detected, "left", left_sequence)
+ right_sequence = self._process_all_frames_depth_alignment(imgs_rgb, right_hand_detected, "right", right_sequence)
+
+ # Save processed sequences and generate visualizations
+ self._save_results(paths, left_sequence, right_sequence)
+
+ def _process_all_frames(self, imgs_rgb: np.ndarray, bboxes: np.ndarray,
+ hand_detections: np.ndarray, hand_side: str) -> HandSequence:
+ """
+ Process all frames in a video sequence to extract hand poses.
+
+ This method iterates through all video frames, applying pose estimation
+ where hands are detected and creating empty frames where they are not.
+ It maintains temporal consistency and provides quality filtering.
+
+ Args:
+ imgs_rgb: RGB video frames, shape (num_frames, height, width, 3)
+ bboxes: Hand bounding boxes per frame, shape (num_frames, 4)
+ hand_detections: Boolean flags indicating valid detections per frame
+ hand_side: "left" or "right" to specify which hand is being processed
+
+ Returns:
+ HandSequence object containing processed pose data for all frames
+ """
+ sequence = HandSequence()
+
+ for img_idx in tqdm(range(len(imgs_rgb)), disable=False, leave=False):
+ if not hand_detections[img_idx]:
+ # Create empty frame for missing detections
+ sequence.add_frame(HandFrame.create_empty_frame(
+ frame_idx=img_idx,
+ img_rgb=imgs_rgb[img_idx],
+ ))
+ continue
+
+ # Process frame with detected hand
+ frame_data = self._process_frame(img_idx, imgs_rgb[img_idx], bboxes[img_idx],
+ hand_side)
+ sequence.add_frame(frame_data)
+
+ return sequence
+
+ def _process_frame(self, img_idx: int, img_rgb: np.ndarray, bbox: np.ndarray,
+ hand_side: str, view: bool = False) -> HandFrame:
+ """
+ Process a single frame to extract hand pose and validate quality.
+
+ This method applies HaMeR pose estimation within the detected bounding box
+ and performs quality checks to ensure the pose is suitable for downstream
+ processing. Poor quality poses (e.g., hands too close to image edges) are
+ rejected to maintain data quality.
+
+ Args:
+ img_idx: Index of the current frame
+ img_rgb: RGB image data for this frame
+ bbox: Hand bounding box coordinates [x1, y1, x2, y2]
+ hand_side: "left" or "right" specifying which hand is being processed
+ view: Whether to display debug visualizations
+
+ Returns:
+ HandFrame object containing pose data or empty frame if quality is poor
+ """
+ try:
+ # Apply HaMeR pose estimation within bounding box
+ processed_data = self._process_image_with_hamer(img_rgb, bbox[None,...], hand_side, img_idx, view=view)
+
+ # Quality check: reject poses where keypoints are too close to image edges
+ if self.are_kpts_too_close_to_margin(processed_data["kpts_2d"], self.W, self.H, margin=5, threshold=0.1):
+ logger.error(f"Error processing frame {img_idx}: Edge hand")
+ return HandFrame.create_empty_frame(
+ frame_idx=img_idx,
+ img_rgb=img_rgb,
+ )
+
+ # Create frame with validated pose data
+ frame_data = HandFrame(
+ frame_idx=img_idx,
+ hand_detected=True,
+ img_rgb=img_rgb,
+ img_hamer=processed_data["img_hamer"],
+ kpts_2d=processed_data["kpts_2d"],
+ kpts_3d=processed_data["kpts_3d"],
+ )
+
+ return frame_data
+
+ except Exception as e:
+ logger.error(f"Error processing frame {img_idx}: {str(e)}")
+ return HandFrame.create_empty_frame(
+ frame_idx=img_idx,
+ img_rgb=img_rgb,
+ )
+
+ def are_kpts_too_close_to_margin(self, kpts_2d: np.ndarray, img_width: int, img_height: int,
+ margin: int = 20, threshold: float = 0.5) -> bool:
+ """
+ Filter hand keypoints based on proximity to image edges.
+
+ This quality check rejects hand poses where too many keypoints are near
+ the image boundaries, which typically indicates partial occlusion or
+ tracking errors that would lead to poor pose estimates.
+
+ Args:
+ kpts_2d: 2D keypoint positions, shape (N, 2) where N is number of keypoints
+ img_width: Image width in pixels
+ img_height: Image height in pixels
+ margin: Distance from edge (in pixels) to consider "too close"
+ threshold: Fraction of keypoints that triggers rejection (e.g., 0.5 = 50%)
+
+ Returns:
+ True if hand should be rejected due to edge proximity, False otherwise
+ """
+ x = kpts_2d[:, 0]
+ y = kpts_2d[:, 1]
+
+ # Create boolean mask for keypoints near any image edge
+ near_edge = (
+ (x < margin) |
+ (y < margin) |
+ (x > img_width - margin) |
+ (y > img_height - margin)
+ )
+
+ frac_near_edge = np.mean(near_edge) # Fraction of keypoints near edge
+ return frac_near_edge > threshold
+
+ def _save_results(self, paths: Paths, left_sequence: HandSequence, right_sequence: HandSequence) -> None:
+ """
+ Save processed hand sequences and generate visualization videos.
+
+ Args:
+ paths: Paths object containing output file locations
+ left_sequence: Processed left hand pose sequence
+ right_sequence: Processed right hand pose sequence
+ """
+ # Create output directory
+ if not os.path.exists(getattr(paths, f"hand_processor")):
+ os.makedirs(getattr(paths, f"hand_processor"))
+
+ # Save hand sequence data in compressed format
+ left_sequence.save(getattr(paths, f"hand_data_left"))
+ right_sequence.save(getattr(paths, f"hand_data_right"))
+
+ # Save RGB frames for reference
+ media.write_video(getattr(paths, f"video_rgb_imgs"), left_sequence.imgs_rgb, fps=10, codec="ffv1")
+
+ # Load additional visualization components
+ imgs_bbox = media.read_video(getattr(paths, f"video_bboxes"))
+
+ # Load segmentation visualization if available
+ if os.path.exists(getattr(paths, f"video_sam_arm")):
+ imgs_sam = media.read_video(getattr(paths, f"video_sam_arm"))
+ else:
+ imgs_sam = np.zeros((len(left_sequence.imgs_rgb), left_sequence.imgs_rgb[0].shape[0], left_sequence.imgs_rgb[0].shape[1], 3))
+
+ # Create comprehensive annotation video showing all processing stages
+ annot_imgs = []
+ for idx in range(len(left_sequence.imgs_rgb)):
+ img_hamer_left = left_sequence.imgs_hamer[idx]
+ img_hamer_right = right_sequence.imgs_hamer[idx]
+ img_bbox = imgs_bbox[idx]
+ img_sam = imgs_sam[idx]
+
+ # Combine visualizations in 2x2 grid: [bbox, sam] on top, [left_hand, right_hand] on bottom
+ annot_img = np.vstack((np.hstack((img_bbox, img_sam)), np.hstack((img_hamer_left, img_hamer_right)))).astype(np.uint8)
+ annot_imgs.append(annot_img)
+
+ # Save comprehensive visualization video
+ media.write_video(getattr(paths, f"video_annot"), np.array(annot_imgs), fps=10, codec="h264") # mp4
+
+ def _create_hand_mesh(self, hamer_out: Dict[str, Any]) -> trimesh.Trimesh:
+ """
+ Create a 3D triangle mesh from HaMeR pose estimation output.
+
+ Args:
+ hamer_out: HaMeR output dictionary containing vertex positions
+
+ Returns:
+ Trimesh object representing the hand mesh
+ """
+ return trimesh.Trimesh(hamer_out["verts"].copy(), self.detector_hamer.faces_left.copy(), process=False)
+
+ def _get_hand_masks(self, data_sub_folder: str, hamer_data_left: HandSequence, hamer_data_right: HandSequence) -> None:
+ """
+ Generate hand segmentation masks using processed pose data.
+
+ This method integrates with the segmentation processor to generate
+ detailed hand masks that can be used for depth-based pose refinement.
+
+ Args:
+ data_sub_folder: Path to demonstration data folder
+ hamer_data_left: Processed left hand sequence for guidance
+ hamer_data_right: Processed right hand sequence for guidance
+ """
+ hamer_data = {
+ "left": hamer_data_left,
+ "right": hamer_data_right
+ }
+ self.hand_mask_processor.process_one_demo(data_sub_folder, hamer_data)
+
+ @staticmethod
+ def _get_visible_pts_from_hamer(detector_hamer: DetectorHamer, hamer_out: Dict[str, Any], mesh: trimesh.Trimesh,
+ img_depth: np.ndarray, cam_intrinsics: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Identify visible hand vertices and their corresponding depth points.
+
+ Args:
+ detector_hamer: HaMeR detector instance for coordinate projections
+ hamer_out: HaMeR output containing pose estimates and camera parameters
+ mesh: 3D hand mesh generated from HaMeR output
+ img_depth: Depth image corresponding to the RGB frame
+ cam_intrinsics: Camera intrinsic parameters for 3D projection
+
+ Returns:
+ Tuple of (visible_points_3d, visible_hamer_vertices):
+ - visible_points_3d: 3D points from depth image at visible mesh locations
+ - visible_hamer_vertices: Corresponding vertices from the HaMeR mesh
+ """
+ # Perform ray-casting to identify visible mesh vertices
+ visible_hamer_vertices, _ = get_visible_points(mesh, origin=np.array([0,0,0]))
+
+ # Project 3D vertices to 2D image coordinates
+ visible_points_2d = detector_hamer.project_3d_kpt_to_2d(
+ (visible_hamer_vertices-hamer_out["T_cam_pred"].cpu().numpy()).astype(np.float32),
+ hamer_out["img_w"], hamer_out["img_h"], hamer_out["scaled_focal_length"],
+ hamer_out["camera_center"], hamer_out["T_cam_pred"])
+
+ # Filter out points that fall outside the depth image boundaries
+ original_visible_points_2d = visible_points_2d.copy()
+
+ # Create valid region mask (note: depth indexing is [y, x])
+ valid_mask = ((original_visible_points_2d[:, 0] < img_depth.shape[1]) &
+ (original_visible_points_2d[:, 1] < img_depth.shape[0]))
+
+ visible_points_2d = visible_points_2d[valid_mask]
+ visible_hamer_vertices = visible_hamer_vertices[valid_mask]
+
+ # Convert 2D depth pixels to 3D points using camera intrinsics
+ visible_points_3d = get_3D_points_from_pixels(visible_points_2d, img_depth, cam_intrinsics)
+
+ return visible_points_3d, visible_hamer_vertices
+
+ @staticmethod
+ def _get_transformation_estimate(visible_points_3d: np.ndarray,
+ visible_hamer_vertices: np.ndarray,
+ pcd: o3d.geometry.PointCloud) -> Tuple[np.ndarray, o3d.geometry.PointCloud]:
+ """
+ Estimate transformation to align HaMeR mesh with observed point cloud.
+
+ This method uses Iterative Closest Point (ICP) registration to find the
+ optimal transformation that aligns the visible parts of the predicted
+ hand mesh with the point cloud extracted from depth and segmentation data.
+
+ Args:
+ visible_points_3d: 3D points from depth image at mesh locations
+ visible_hamer_vertices: Corresponding vertices from HaMeR mesh
+ pcd: Point cloud from segmentation and depth data
+
+ Returns:
+ Tuple of (transformation_matrix, aligned_mesh_pointcloud):
+ - transformation_matrix: 4x4 transformation to align mesh with depth
+ - aligned_mesh_pointcloud: Transformed mesh point cloud after alignment
+ """
+ # Get initial transformation estimate using median translation
+ T_0 = HandBaseProcessor._get_initial_transformation_estimate(visible_points_3d, visible_hamer_vertices)
+
+ # Create point cloud from visible mesh vertices
+ visible_hamer_pcd = get_pcd_from_points(visible_hamer_vertices, colors=np.ones_like(visible_hamer_vertices) * [0, 1, 0])
+
+ try:
+ # Apply ICP registration for fine alignment
+ aligned_hamer_pcd, T = icp_registration(visible_hamer_pcd, pcd, voxel_size=0.005, init_transform=T_0)
+ except Exception as e:
+ logger.error(f"ICP registration failed: {e}")
+ return T_0, visible_hamer_pcd
+
+ return T, aligned_hamer_pcd
+
+ @staticmethod
+ def _get_initial_transformation_estimate(visible_points_3d: np.ndarray,
+ visible_hamer_vertices: np.ndarray) -> np.ndarray:
+ """
+ Compute initial transformation estimate for mesh-to-depth alignment.
+
+ This method provides a coarse alignment between the HaMeR prediction and
+ the depth-based point cloud using median translation. It assumes that
+ orientation is approximately correct and only translation correction is needed.
+
+ Args:
+ visible_points_3d: 3D points from depth image
+ visible_hamer_vertices: Corresponding HaMeR mesh vertices
+
+ Returns:
+ 4x4 transformation matrix with estimated translation
+ """
+ # Calculate median translation between corresponding point sets
+ translation = np.nanmedian(visible_points_3d - visible_hamer_vertices, axis=0)
+
+ # Create transformation matrix (identity rotation + translation)
+ T_0 = np.eye(4)
+ if not np.isnan(translation).any():
+ T_0[:3, 3] = translation
+
+ return T_0
+
+
+class Hand2DProcessor(HandBaseProcessor):
+ """
+ 2D hand pose processor optimized for speed and RGB-only operation.
+
+ This processor focuses on extracting 2D hand poses and basic 3D estimates
+ without depth sensor integration. It's designed for applications where
+ depth sensors are not available.
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize 2D hand processor with RGB-only configuration.
+
+ Args:
+ args: Command line arguments for processor configuration
+ """
+ super().__init__(args)
+
+ def _process_image_with_hamer(self, img_rgb: np.ndarray, bboxes: np.ndarray, hand_side: str,
+ img_idx: int, view: bool = False) -> Dict[str, Any]:
+ """
+ Process RGB image with HaMeR for 2D pose estimation.
+
+ Args:
+ img_rgb: RGB image to process
+ bboxes: Hand bounding boxes for pose estimation guidance
+ hand_side: "left" or "right" specifying which hand to process
+ img_idx: Frame index for debugging and logging
+ view: Whether to display debug visualizations
+
+ Returns:
+ Dictionary containing:
+ - img_hamer: Annotated image with pose visualization
+ - kpts_3d: Estimated 3D keypoints
+ - kpts_2d: 2D keypoint projections in image coordinates
+
+ Raises:
+ ValueError: If no valid hand pose is detected in the image
+ """
+ # Configure HaMeR for target hand side
+ is_right = np.array([hand_side_dict[str(hand_side)]*True]*len(bboxes))
+
+ # Apply HaMeR pose estimation
+ hamer_out = self.detector_hamer.detect_hand_keypoints(
+ img_rgb,
+ hand_side=hand_side,
+ bboxes=bboxes,
+ is_right=is_right,
+ camera_params=self.intrinsics_dict,
+ visualize=False
+ )
+
+ if hamer_out is None or not hamer_out.get("success", False):
+ raise ValueError("No hand detected in image")
+
+ return {
+ "img_hamer": hamer_out["annotated_img"][:,:,::-1], # Convert BGR to RGB
+ "kpts_3d": hamer_out["kpts_3d"],
+ "kpts_2d": hamer_out['kpts_2d']
+ }
+
+class Hand3DProcessor(HandBaseProcessor):
+ """
+ 3D hand pose processor with depth-based refinement capabilities.
+
+ This processor provides more accurate 3D hand poses by combining HaMeR
+ estimation with depth sensor data and hand segmentation. It uses point cloud
+ registration techniques to refine the initial pose estimates, resulting in
+ poses that are better aligned with the physical environment.
+
+ Processing Enhancements:
+ - Mesh generation from HaMeR output for visibility analysis
+ - Hand segmentation using SAM2 for accurate depth extraction
+ - ICP-based alignment between predicted mesh and observed point cloud
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize 3D hand processor with depth refinement capabilities.
+
+ Args:
+ args: Command line arguments containing depth processing configuration
+ """
+ super().__init__(args)
+ self.args = args
+
+ # Storage for HaMeR outputs needed for depth alignment
+ self.hamer_out_dict: Dict[str, Dict[int, Dict[str, Any]]] = {
+ "left": defaultdict(dict),
+ "right": defaultdict(dict)
+ }
+
+ # Enable advanced processing features
+ self.process_hand_masks = True
+ self.apply_depth_alignment = True
+ self.hand_mask_processor = HandSegmentationProcessor(self.args)
+
+ def _process_image_with_hamer(self, img_rgb: np.ndarray, bboxes: np.ndarray, hand_side: str,
+ img_idx: int, view: bool = False) -> Dict[str, Any]:
+ """
+ Process RGB image with HaMeR optimized for subsequent depth refinement.
+
+ This method applies HaMeR pose estimation configured for 3D processing,
+ storing intermediate results needed for later depth-based refinement.
+
+ Args:
+ img_rgb: RGB image to process
+ bboxes: Hand bounding boxes for pose estimation guidance
+ hand_side: "left" or "right" specifying which hand to process
+ img_idx: Frame index for result storage and debugging
+ view: Whether to display debug visualizations
+
+ Returns:
+ Dictionary containing pose estimation results
+
+ Raises:
+ ValueError: If no valid hand pose is detected in the image
+ """
+ # Configure HaMeR for target hand side
+ is_right = np.array([hand_side_dict[str(hand_side)]*True]*len(bboxes))
+
+ # Apply HaMeR with 2D keypoint focus (3D refinement happens later)
+ hamer_out = self.detector_hamer.detect_hand_keypoints(
+ img_rgb,
+ hand_side=hand_side,
+ bboxes=bboxes,
+ is_right=is_right,
+ kpts_2d_only=True, # Initial processing focuses on 2D
+ camera_params=self.intrinsics_dict
+ )
+
+ if hamer_out is None or not hamer_out.get("success", False):
+ raise ValueError("No hand detected in image")
+
+ # Store HaMeR output for later depth alignment processing
+ self.hamer_out_dict[hand_side][img_idx] = hamer_out
+
+ return {
+ "img_hamer": hamer_out["annotated_img"][:,:,::-1], # Convert BGR to RGB
+ "kpts_3d": hamer_out["kpts_3d"],
+ "kpts_2d": hamer_out['kpts_2d']
+ }
+
+ def _process_all_frames_depth_alignment(self, imgs_rgb: np.ndarray, hand_detections: np.ndarray,
+ hand_side: str, sequence: Optional[HandSequence] = None) -> HandSequence:
+ """
+ Apply depth-based refinement to all frames in the sequence.
+
+ This method performs the depth alignment stage of processing, using
+ segmentation masks and depth data to refine the initial HaMeR pose
+ estimates for improved 3D accuracy.
+
+ Args:
+ imgs_rgb: RGB video frames for reference
+ hand_detections: Boolean flags indicating frames with valid detections
+ hand_side: "left" or "right" specifying which hand to process
+ sequence: HandSequence containing initial pose estimates to refine
+
+ Returns:
+ HandSequence with refined 3D poses aligned to depth data
+ """
+ for img_idx in tqdm(range(len(imgs_rgb)), disable=False, leave=False):
+ if not hand_detections[img_idx]:
+ continue
+
+ # Apply depth-based refinement to this frame
+ frame_data = sequence.get_frame(img_idx)
+ frame_data.kpts_3d = self._depth_alignment(img_idx, hand_side, imgs_rgb[img_idx])
+ sequence.modify_frame(img_idx, frame_data)
+
+ return sequence
+
+ def _depth_alignment(self, img_idx: int, hand_side: str, img_rgb: np.ndarray) -> np.ndarray:
+ """
+ Perform depth-based pose refinement for a single frame.
+
+ Algorithm Steps:
+ 1. Extract depth image and segmentation mask for the frame
+ 2. Obtain 3D hand mesh from HaMeR output
+ 3. Create point cloud from segmented depth region
+ 4. Identify visible mesh vertices through ray casting
+ 5. Apply ICP registration between mesh and point cloud
+ 6. Transform original keypoints using computed alignment
+
+ Args:
+ img_idx: Index of the frame to process
+ hand_side: "left" or "right" specifying which hand to process
+ img_rgb: RGB image for reference (used in point cloud generation)
+
+ Returns:
+ Refined 3D keypoint positions aligned with depth data
+ """
+ # Load frame-specific data
+ img_depth = self.imgs_depth[img_idx]
+ mask = self.left_masks[img_idx] if hand_side == "left" else self.right_masks[img_idx]
+ hamer_out = self.hamer_out_dict[hand_side][img_idx]
+
+ # Create 3D hand mesh from HaMeR pose estimate
+ mesh = self._create_hand_mesh(hamer_out)
+
+ # Generate point cloud from depth image within segmented hand region
+ pcd = get_point_cloud_of_segmask(mask, img_depth, img_rgb, self.intrinsics_dict, visualize=False)
+
+ # Identify visible mesh vertices and corresponding depth points
+ visible_points_3d, visible_hamer_vertices = self._get_visible_pts_from_hamer(
+ self.detector_hamer,
+ hamer_out,
+ mesh,
+ img_depth,
+ self.intrinsics_dict
+ )
+
+ # Compute optimal transformation using ICP registration
+ T, _ = self._get_transformation_estimate(visible_points_3d, visible_hamer_vertices, pcd)
+
+ # Apply transformation to refine original keypoint positions
+ kpts_3d = transform_pts(hamer_out["kpts_3d"], T)
+
+ return kpts_3d
\ No newline at end of file
diff --git a/phantom/phantom/processors/handinpaint_processor.py b/phantom/phantom/processors/handinpaint_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ac9d948f497c8ca6b74fb84b646263aa82025ec
--- /dev/null
+++ b/phantom/phantom/processors/handinpaint_processor.py
@@ -0,0 +1,485 @@
+"""
+Hand Inpainting Processor Module
+
+This module removes human hands from demonstration videos using the E2FGVI model.
+
+Paper:
+Towards An End-to-End Framework for Flow-Guided Video Inpainting
+https://github.com/MCG-NKU/E2FGVI.git
+
+Processing Pipeline:
+1. Load pre-trained E2FGVI model and initialize GPU processing
+2. Read input video frames and corresponding hand segmentation masks
+3. Process frames in batches with neighboring temporal context
+4. Apply mask-guided inpainting to remove hand regions
+5. Verify complete processing and handle any missed frames
+6. Save final hand-free video for robot learning applications
+"""
+
+import cv2
+from PIL import Image
+import numpy as np
+import os
+from pathlib import Path
+from tqdm import tqdm
+import torch
+import mediapy as media
+import logging
+import gc
+from typing import List, Tuple, Optional, Any, Union
+
+from phantom.processors.base_processor import BaseProcessor
+from phantom.utils.data_utils import get_parent_folder_of_package
+from E2FGVI.model.e2fgvi_hq import InpaintGenerator # type: ignore
+from E2FGVI.core.utils import to_tensors # type: ignore
+
+DEFAULT_CHECKPOINT = 'E2FGVI/release_model/E2FGVI-HQ-CVPR22.pth'
+
+logger = logging.getLogger(__name__)
+
+class HandInpaintProcessor(BaseProcessor):
+ """
+ Hand inpainting processor for removing human hands from demonstration videos.
+
+ Attributes:
+ model: E2FGVI neural network model for video inpainting
+ device: GPU/CPU device for model execution
+ ref_length (int): Spacing between reference frames for temporal consistency
+ num_ref (int): Number of reference frames to use (-1 for automatic)
+ neighbor_stride (int): Spacing between neighboring frames in temporal context
+ batch_size (int): Number of frame groups to process simultaneously
+ scale_factor (int): Resolution scaling factor for processing optimization
+ """
+
+ def __init__(self, args: Any) -> None:
+ """
+ Initialize the hand inpainting processor with E2FGVI model and parameters.
+
+ Args:
+ args: Command line arguments containing processing configuration
+ including scale factor and other inpainting parameters
+ """
+ super().__init__(args)
+
+ # Load pre-trained E2FGVI model
+ root_dir = get_parent_folder_of_package("E2FGVI")
+ checkpoint_path = Path(root_dir, DEFAULT_CHECKPOINT)
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Initialize and load the inpainting model
+ self.model = InpaintGenerator().to(self.device)
+ data = torch.load(checkpoint_path, map_location=self.device)
+ self.model.load_state_dict(data)
+ self.model.eval()
+
+ # Configure temporal processing parameters
+ self.ref_length: int = 20 # Spacing between reference frames
+ self.num_ref: int = -1 # Number of reference frames (-1 = automatic)
+ self.neighbor_stride: int = 5 # Stride for neighboring frame selection
+
+ # Configure batch processing parameters for memory optimization
+ self.batch_size: int = 10 # Number of frame groups per batch
+ self.scale_factor: int = getattr(args, 'scale_factor', 2) # Resolution scaling
+
+ def _clear_gpu_memory(self) -> None:
+ """Clear GPU memory cache and trigger garbage collection."""
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process a single demonstration video to remove hand regions.
+
+ Args:
+ data_sub_folder: Path to demonstration data folder containing
+ input video and hand segmentation masks
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+ paths = self.get_paths(save_folder)
+ if not os.path.exists(paths.inpaint_processor):
+ os.makedirs(paths.inpaint_processor)
+
+ self._process_frames(paths)
+
+ def _process_frames(self, paths: Any) -> None:
+ """
+ Process all video frames to remove hand regions using E2FGVI inpainting.
+
+ Args:
+ paths: Paths object containing input video and mask file locations
+ """
+ # Load and prepare video frames
+ frames = self._load_and_prepare_frames(paths)
+ video_length = len(frames)
+ logger.info(f"Processing {video_length} frames")
+
+ # Initialize tracking arrays for processed frames
+ comp_frames: List[Optional[np.ndarray]] = [None] * video_length
+ processed_frame_mask: List[bool] = [False] * video_length
+
+ # Process frames in batches with temporal overlap for consistency
+ self._process_frames_in_batches(frames, paths, comp_frames, processed_frame_mask)
+
+ # Handle any missed frames
+ self._process_missed_frames(frames, paths, comp_frames, processed_frame_mask)
+
+ # Final verification and save
+ self._verify_and_save_results(comp_frames, paths)
+
+ def _load_and_prepare_frames(self, paths: Any) -> List[Image.Image]:
+ """Load video frames and prepare them for processing."""
+ frames = self.read_frame_from_videos(paths.video_rgb_imgs)
+
+ # Calculate output dimensions based on configuration
+ h, w = frames[0].height, frames[0].width
+
+ if self.epic:
+ size = (w, h)
+ else:
+ if self.square:
+ output_resolution = np.array([self.output_resolution, self.output_resolution])
+ else:
+ output_resolution = np.array([int(w/h*self.output_resolution), self.output_resolution])
+ output_resolution = output_resolution.astype(np.int32)
+ size = output_resolution
+ frames, size = self.resize_frames(frames, size)
+
+ return frames
+
+ def _process_frames_in_batches(self, frames: List[Image.Image], paths: Any,
+ comp_frames: List[Optional[np.ndarray]],
+ processed_frame_mask: List[bool]) -> None:
+ """Process frames in batches with temporal overlap."""
+ video_length = len(frames)
+ h, w = frames[0].height, frames[0].width
+
+ for batch_start in tqdm(range(0, video_length, self.batch_size * self.neighbor_stride),
+ desc="Processing batches"):
+ batch_end = min(batch_start + self.batch_size * self.neighbor_stride + self.neighbor_stride, video_length)
+
+ # Prepare batch data
+ batch_data = self._prepare_batch_data(frames, paths, batch_start, batch_end, h, w)
+
+ # Process frames within batch
+ self._process_batch_frames(frames, batch_data, batch_start, batch_end,
+ comp_frames, processed_frame_mask, h, w)
+
+ # Clean up batch memory
+ del batch_data['batch_imgs'], batch_data['batch_masks']
+ self._clear_gpu_memory()
+
+ def _prepare_batch_data(self, frames: List[Image.Image], paths: Any,
+ batch_start: int, batch_end: int, h: int, w: int) -> dict:
+ """Prepare batch data including frames, masks, and binary masks."""
+ batch_frames = frames[batch_start:batch_end]
+ batch_imgs = to_tensors()(batch_frames).unsqueeze(0).to(self.device) * 2 - 1
+
+ batch_masks = self.read_mask(paths.masks_arm, (w, h))[batch_start:batch_end]
+ batch_masks = to_tensors()(batch_masks).unsqueeze(0).to(self.device)
+
+ binary_masks = self._create_binary_masks(paths.masks_arm, batch_start, batch_end, w, h)
+
+ return {
+ 'batch_imgs': batch_imgs,
+ 'batch_masks': batch_masks,
+ 'binary_masks': binary_masks
+ }
+
+ def _create_binary_masks(self, mask_path: str, batch_start: int, batch_end: int,
+ w: int, h: int) -> List[np.ndarray]:
+ """Create binary masks for the batch."""
+ masks = self.read_mask(mask_path, (w, h))[batch_start:batch_end]
+ binary_masks = []
+
+ for mask in masks:
+ mask_array = np.array(mask)
+ binary_mask = np.expand_dims((mask_array != 0).astype(np.uint8), 2)
+ binary_mask = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
+ binary_mask = np.expand_dims(binary_mask, 2)
+ binary_masks.append(binary_mask)
+
+ return binary_masks
+
+ def _process_batch_frames(self, frames: List[Image.Image], batch_data: dict,
+ batch_start: int, batch_end: int,
+ comp_frames: List[Optional[np.ndarray]],
+ processed_frame_mask: List[bool], h: int, w: int) -> None:
+ """Process individual frames within a batch."""
+ stride = max(1, self.neighbor_stride if batch_start + self.batch_size * self.neighbor_stride < len(frames) else 1)
+
+ for frame_idx in range(batch_start, batch_end, stride):
+ neighbor_ids = self._get_neighbor_ids(frame_idx, batch_start, batch_end)
+ ref_ids = self.get_ref_index(frame_idx, neighbor_ids, batch_end)
+
+ if not neighbor_ids:
+ continue
+
+ # Convert to batch-relative indices
+ batch_neighbor_ids = [i - batch_start for i in neighbor_ids]
+ batch_ref_ids = [i - batch_start for i in ref_ids if batch_start <= i < batch_end]
+
+ # Process frame with temporal context
+ self._process_single_frame(frames, batch_data, neighbor_ids, batch_neighbor_ids,
+ batch_ref_ids, comp_frames, processed_frame_mask, h, w)
+
+ self._clear_gpu_memory()
+
+ def _get_neighbor_ids(self, frame_idx: int, batch_start: int, batch_end: int) -> List[int]:
+ """Get neighboring frame indices for temporal context."""
+ return list(range(
+ max(batch_start, frame_idx - self.neighbor_stride),
+ min(batch_end, frame_idx + self.neighbor_stride + 1)
+ ))
+
+ def _process_single_frame(self, frames: List[Image.Image], batch_data: dict,
+ neighbor_ids: List[int], batch_neighbor_ids: List[int],
+ batch_ref_ids: List[int], comp_frames: List[Optional[np.ndarray]],
+ processed_frame_mask: List[bool], h: int, w: int) -> None:
+ """Process a single frame with its temporal context."""
+ batch_start = neighbor_ids[0] - batch_neighbor_ids[0]
+
+ # Select relevant frames and masks
+ selected_imgs = batch_data['batch_imgs'][:, batch_neighbor_ids + batch_ref_ids, :, :, :]
+ selected_masks = batch_data['batch_masks'][:, batch_neighbor_ids + batch_ref_ids, :, :]
+
+ with torch.no_grad():
+ # Apply masks and generate inpainted frames
+ masked_imgs = selected_imgs * (1 - selected_masks)
+ masked_imgs = self._pad_images(masked_imgs, h, w)
+
+ pred_imgs, _ = self.model(masked_imgs, len(batch_neighbor_ids))
+ pred_imgs = (pred_imgs[:, :, :h, :w] + 1) / 2
+ pred_imgs = (pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
+
+ # Composite with original background
+ for i, idx in enumerate(neighbor_ids):
+ binary_mask = batch_data['binary_masks'][idx - batch_start]
+ original_frame = np.array(frames[idx])
+
+ inpainted_frame = (pred_imgs[i] * binary_mask +
+ original_frame * (1 - binary_mask))
+
+ # Average with previous results if frame was already processed
+ if comp_frames[idx] is None:
+ comp_frames[idx] = inpainted_frame
+ else:
+ comp_frames[idx] = ((comp_frames[idx].astype(np.float32) +
+ inpainted_frame.astype(np.float32)) / 2).astype(np.uint8)
+ processed_frame_mask[idx] = True
+
+ def _process_missed_frames(self, frames: List[Image.Image], paths: Any,
+ comp_frames: List[Optional[np.ndarray]],
+ processed_frame_mask: List[bool]) -> None:
+ """Process any frames that were missed during batch processing."""
+ unprocessed_frames = [i for i, processed in enumerate(processed_frame_mask) if not processed]
+
+ if not unprocessed_frames:
+ return
+
+ logger.warning(f"Found {len(unprocessed_frames)} unprocessed frames at indices: {unprocessed_frames}")
+
+ # Determine processing context for missed frames
+ start_idx, end_idx = self._get_missed_frame_context(unprocessed_frames, processed_frame_mask, len(frames))
+
+ logger.info(f"Processing missed frames from {start_idx} to {end_idx}")
+ self._process_missed_frame_sequence(frames, paths, unprocessed_frames,
+ start_idx, end_idx, comp_frames, processed_frame_mask)
+
+ def _get_missed_frame_context(self, unprocessed_frames: List[int],
+ processed_frame_mask: List[bool], video_length: int) -> Tuple[int, int]:
+ """Get the context range for processing missed frames."""
+ last_processed_idx = max([i for i, processed in enumerate(processed_frame_mask[:unprocessed_frames[0]])
+ if processed], default=-1)
+ if last_processed_idx == -1:
+ last_processed_idx = 0
+
+ next_processed_idx = min([i for i, processed in enumerate(processed_frame_mask[unprocessed_frames[-1]:],
+ start=unprocessed_frames[-1]) if processed], default=video_length)
+
+ start_idx = max(0, last_processed_idx - self.neighbor_stride)
+ end_idx = min(video_length, next_processed_idx + self.neighbor_stride)
+
+ return start_idx, end_idx
+
+ def _process_missed_frame_sequence(self, frames: List[Image.Image], paths: Any,
+ unprocessed_frames: List[int], start_idx: int, end_idx: int,
+ comp_frames: List[Optional[np.ndarray]],
+ processed_frame_mask: List[bool]) -> None:
+ """Process the sequence containing missed frames."""
+ h, w = frames[0].height, frames[0].width
+
+ # Prepare sequence data
+ batch_frames = frames[start_idx:end_idx]
+ batch_imgs = to_tensors()(batch_frames).unsqueeze(0).to(self.device) * 2 - 1
+
+ batch_masks = self.read_mask(paths.masks_arm, (w, h))[start_idx:end_idx]
+ batch_masks = to_tensors()(batch_masks).unsqueeze(0).to(self.device)
+
+ binary_masks = self._create_binary_masks(paths.masks_arm, start_idx, end_idx, w, h)
+
+ # Process each missed frame
+ for idx in tqdm(unprocessed_frames, desc="Processing missed frames"):
+ self._process_missed_single_frame(frames, batch_imgs, batch_masks, binary_masks,
+ idx, start_idx, end_idx, comp_frames, processed_frame_mask, h, w)
+
+ del batch_imgs, batch_masks
+ self._clear_gpu_memory()
+
+ def _process_missed_single_frame(self, frames: List[Image.Image], batch_imgs: torch.Tensor,
+ batch_masks: torch.Tensor, binary_masks: List[np.ndarray],
+ frame_idx: int, start_idx: int, end_idx: int,
+ comp_frames: List[Optional[np.ndarray]],
+ processed_frame_mask: List[bool], h: int, w: int) -> None:
+ """Process a single missed frame."""
+ relative_start = frame_idx - start_idx
+ neighbor_ids = list(range(
+ max(0, relative_start - self.neighbor_stride),
+ min(end_idx - start_idx, relative_start + self.neighbor_stride + 1)
+ ))
+ ref_ids = self.get_ref_index(relative_start, neighbor_ids, end_idx - start_idx)
+
+ with torch.no_grad():
+ selected_imgs = batch_imgs[:, neighbor_ids + ref_ids, :, :, :]
+ selected_masks = batch_masks[:, neighbor_ids + ref_ids, :, :]
+
+ masked_imgs = selected_imgs * (1 - selected_masks)
+ masked_imgs = self._pad_images(masked_imgs, h, w)
+
+ pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
+ pred_imgs = (pred_imgs[:, :, :h, :w] + 1) / 2
+ pred_imgs = (pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
+
+ relative_idx = frame_idx - start_idx - neighbor_ids[0]
+ binary_mask = binary_masks[frame_idx - start_idx]
+ original_frame = np.array(frames[frame_idx])
+
+ inpainted_frame = (pred_imgs[relative_idx] * binary_mask +
+ original_frame * (1 - binary_mask))
+ comp_frames[frame_idx] = inpainted_frame
+ processed_frame_mask[frame_idx] = True
+
+ def _verify_and_save_results(self, comp_frames: List[Optional[np.ndarray]], paths: Any) -> None:
+ """Verify all frames were processed and save the final video."""
+ missing_frames = [i for i, frame in enumerate(comp_frames)
+ if frame is None or (isinstance(frame, np.ndarray) and frame.size == 0)]
+
+ if missing_frames:
+ raise RuntimeError(f"Still found unprocessed frames after cleanup: {missing_frames}")
+
+ logger.info("Successfully processed all frames")
+
+ # Save final inpainted video
+ media.write_video(paths.video_human_inpaint, comp_frames, fps=15, codec="ffv1")
+
+ def get_ref_index(self, f: int, neighbor_ids: List[int], length: int) -> List[int]:
+ """
+ Select reference frame indices for temporal consistency.
+
+ Args:
+ f: Current frame index
+ neighbor_ids: List of neighboring frame indices
+ length: Total length of the sequence
+
+ Returns:
+ List of reference frame indices for temporal consistency
+ """
+ if self.num_ref == -1:
+ # Automatic reference selection: every ref_length frames not in neighbors
+ ref_index = [
+ i for i in range(0, length, self.ref_length)
+ if i not in neighbor_ids
+ ]
+ else:
+ # Limited reference selection: specific number around current frame
+ ref_index = []
+ for i in range(max(0, f - self.ref_length * (self.num_ref // 2)),
+ min(length, f + self.ref_length * (self.num_ref // 2)) + 1,
+ self.ref_length):
+ if i not in neighbor_ids and len(ref_index) < self.num_ref:
+ ref_index.append(i)
+ return ref_index
+
+ @staticmethod
+ def read_mask(mask_path: str, size: Tuple[int, int]) -> List[Image.Image]:
+ """
+ Load and process hand segmentation masks for inpainting guidance.
+
+ Args:
+ mask_path: Path to mask file containing hand segmentation data
+ size: Target size (width, height) for mask resizing
+
+ Returns:
+ List of processed PIL Images containing binary hand masks
+ """
+ masks = []
+ frames_media = np.load(mask_path, allow_pickle=True)
+ frames = [frame for frame in frames_media]
+
+ for mask_frame in frames:
+ # Convert to PIL Image and resize
+ mask_img = Image.fromarray(mask_frame)
+ mask_img = mask_img.resize(size, Image.NEAREST)
+ mask_array = np.array(mask_img.convert('L'))
+
+ # Create binary mask
+ binary_mask = np.array(mask_array > 0).astype(np.uint8)
+
+ # Apply morphological dilation to expand mask boundaries
+ # This helps ensure complete coverage of hand regions
+ dilated_mask = cv2.dilate(binary_mask,
+ cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
+ iterations=4)
+ masks.append(Image.fromarray(dilated_mask * 255))
+ return masks
+
+ @staticmethod
+ def read_frame_from_videos(video_path: str) -> List[Image.Image]:
+ """
+ Load video frames and convert to PIL Images.
+
+ Args:
+ video_path: Path to video file
+
+ Returns:
+ List of PIL Images containing video frames
+ """
+ return [Image.fromarray(frame) for frame in media.read_video(video_path)]
+
+ @staticmethod
+ def resize_frames(frames: List[Image.Image], size: Optional[Tuple[int, int]] = None) -> Tuple[List[Image.Image], Tuple[int, int]]:
+ """
+ Resize video frames to target resolution.
+
+ Args:
+ frames: List of PIL Images to resize
+ size: Target size (width, height), or None to keep original
+
+ Returns:
+ Tuple containing resized frames and final size
+ """
+ return ([f.resize(size) for f in frames], size)
+
+ @staticmethod
+ def _pad_images(img_tensor: torch.Tensor, h: int, w: int) -> torch.Tensor:
+ """
+ Pad image tensor to meet model input requirements.
+
+ Args:
+ img_tensor: Input image tensor to pad
+ h: Original height
+ w: Original width
+
+ Returns:
+ Padded image tensor suitable for model input
+ """
+ # Model requires specific dimension multiples
+ mod_size_h, mod_size_w = 60, 108
+
+ # Calculate required padding
+ h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
+ w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
+
+ # Apply reflection padding to avoid boundary artifacts
+ img_tensor = torch.cat([img_tensor, torch.flip(img_tensor, [3])], 3)[:, :, :, :h + h_pad, :]
+ return torch.cat([img_tensor, torch.flip(img_tensor, [4])], 4)[:, :, :, :, :w + w_pad]
+
diff --git a/phantom/phantom/processors/paths.py b/phantom/phantom/processors/paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3007dd0f567d26cc9d48f874fbc743676a39b30
--- /dev/null
+++ b/phantom/phantom/processors/paths.py
@@ -0,0 +1,219 @@
+"""
+Path management for Phantom.
+"""
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Dict, Optional
+import yaml
+
+from phantom.utils.image_utils import convert_video_to_images
+
+@dataclass
+class Paths:
+ """Data class containing all file paths used by processors."""
+ data_path: Path
+ robot_name: str = "franka"
+
+ def __post_init__(self):
+ """Compute derived paths based on base paths."""
+ # Convert string paths to Path objects if needed
+ if isinstance(self.data_path, str):
+ self.data_path = Path(self.data_path)
+
+ # Validate data path
+ if not self.data_path.exists():
+ raise FileNotFoundError(f"Data path does not exist: {self.data_path}")
+
+ # Videos
+ self.video_left = self.data_path / "video_L.mp4"
+ self.video_right = self.data_path / "video_R.mp4"
+ self.video_rgb_imgs = self.data_path / "video_rgb_imgs.mkv"
+
+ # Image folders
+ self.original_images_folder = self.data_path / "original_images"
+ # self._setup_original_images()
+ self.original_images_folder_reverse = self.data_path / "original_images_reverse"
+ # self._setup_original_images_reverse()
+
+ # Epic annotations
+ self.hand_detection_data = self.data_path / "hand_det.pkl"
+ self.cam_extrinsics_data = self.data_path / "extrinsics.npy"
+
+ # Depth
+ self.depth = self.data_path / "depth.npy"
+
+ # Bbox processor
+ self.bbox_processor = self.data_path / "bbox_processor"
+ self.bbox_data = self.bbox_processor / "bbox_data.npz"
+ self.video_bboxes = self.bbox_processor / "video_bboxes.mkv"
+
+ # Segmentation processor
+ self.segmentation_processor = self.data_path / "segmentation_processor"
+ self.masks_arm = self.segmentation_processor / "masks_arm.npy"
+ self.video_masks_arm = self.segmentation_processor / "video_masks_arm.mkv"
+ self.video_sam_arm = self.segmentation_processor / "video_sam_arm.mkv"
+ for side in ["left", "right"]:
+ setattr(self, f"masks_hand_{side}", self.segmentation_processor / f"masks_hand_{side}.npy")
+ setattr(self, f"video_masks_hand_{side}", self.segmentation_processor / f"video_masks_hand_{side}.mkv")
+ setattr(self, f"video_sam_hand_{side}", self.segmentation_processor / f"video_sam_hand_{side}.mkv")
+
+ # Hand Processor
+ self.hand_processor = self.data_path / f"hand_processor"
+ for side in ["left", "right"]:
+ setattr(self, f"hand_data_{side}", self.hand_processor / f"hand_data_{side}.npz")
+ setattr(self, f"hand_data_3d_{side}", self.hand_processor / f"hand_data_3d_{side}.npz")
+ self.video_annot = self.data_path / "video_annot.mp4"
+
+ # Action processor
+ self.action_processor = self.data_path / "action_processor"
+ for side in ["left", "right"]:
+ setattr(self, f"actions_{side}", self.action_processor / f"actions_{side}.npz")
+
+ # Smoothing processor
+ self.smoothing_processor = self.data_path / f"smoothing_processor"
+ for side in ["left", "right"]:
+ setattr(self, f"smoothed_actions_{side}", self.smoothing_processor / f"smoothed_actions_{side}.npz")
+
+ # Inpaint processor
+ self.inpaint_processor = self.data_path / "inpaint_processor"
+ self.video_overlay = self.data_path / "video_overlay.mkv"
+ self.video_human_inpaint = self.inpaint_processor / "video_human_inpaint.mkv"
+ self.video_inpaint_overlay = self.inpaint_processor / "video_inpaint_overlay.mkv"
+ self.video_birdview = self.inpaint_processor / "video_birdview.mkv"
+ self.training_data = self.inpaint_processor / "training_data.npz"
+
+ def _setup_original_images(self):
+ """Set up original images paths."""
+ convert_video_to_images(self.video_left, self.original_images_folder, square=False)
+ image_paths = sorted(
+ list(self.original_images_folder.glob("*.jpg")),
+ key=lambda x: int(x.stem)
+ )
+ self.original_images = image_paths
+
+ def _setup_original_images_reverse(self):
+ """Set up original images paths."""
+ convert_video_to_images(self.video_left, self.original_images_folder_reverse, square=False, reverse=True)
+ image_paths = sorted(
+ list(self.original_images_folder_reverse.glob("*.jpg")),
+ key=lambda x: int(x.stem)
+ )
+ self.original_images_reverse = image_paths
+
+ def ensure_directories_exist(self):
+ """
+ Create necessary directories if they don't exist.
+ """
+ # Create all necessary directories
+ directories = [
+ self.data_path,
+ ]
+
+ for directory in directories:
+ if isinstance(directory, Path) and not directory.exists():
+ directory.mkdir(parents=True, exist_ok=True)
+
+
+
+class PathsConfig:
+ """
+ Configuration for paths used in the project.
+
+ This class handles loading and saving path configurations from files,
+ and provides methods for creating Paths objects.
+ """
+
+ def __init__(self, config_file: Optional[str] = None) -> None:
+ """
+ Initialize paths configuration.
+
+ Args:
+ config_file: Path to configuration file. If None, use default config.
+ """
+ self.config: dict[str, str] = {}
+ if config_file:
+ self.load_config(config_file)
+ else:
+ self.set_default_config()
+
+ def load_config(self, config_file: str) -> None:
+ """
+ Load configuration from a YAML file.
+
+ Args:
+ config_file: Path to configuration file
+
+ Raises:
+ FileNotFoundError: If config file doesn't exist
+ yaml.YAMLError: If config file is invalid YAML
+ """
+ try:
+ with open(config_file, 'r') as f:
+ self.config = yaml.safe_load(f)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found: {config_file}")
+ except yaml.YAMLError as e:
+ raise yaml.YAMLError(f"Invalid YAML in configuration file {config_file}: {e}")
+
+ def save_config(self, config_file: str) -> None:
+ """
+ Save configuration to a YAML file.
+
+ Args:
+ config_file: Path to save configuration file
+
+ Raises:
+ OSError: If unable to write to the file
+ """
+ with open(config_file, 'w') as f:
+ yaml.dump(self.config, f, default_flow_style=False)
+
+ def set_default_config(self) -> None:
+ """Set default configuration values."""
+ self.config = {
+ 'data_root': './data',
+ 'processed_root': './processed_data',
+ 'project_name': 'phantom',
+ }
+
+ def get_paths(self, demo_name: str, robot_name: str = "franka") -> Paths:
+ """
+ Get Paths object for a specific demo.
+
+ Args:
+ demo_name: Name of the demo
+ robot_name: Name of the robot
+
+ Returns:
+ Paths object for the demo
+ """
+ data_path = os.path.join(self.config['data_root'], demo_name)
+
+ return Paths(
+ data_path=Path(data_path),
+ robot_name=robot_name
+ )
+
+ def get_all_demo_paths(self) -> List[str]:
+ """
+ Get list of all demo paths in data root.
+
+ Returns:
+ List of demo paths
+ """
+ data_root = self.config['data_root']
+ all_data_collection_folders = [
+ f for f in os.listdir(data_root)
+ if os.path.isdir(os.path.join(data_root, f))
+ ]
+
+ all_data_folders = [
+ os.path.join(d1, d2)
+ for d1 in os.listdir(data_root)
+ if os.path.isdir(os.path.join(data_root, d1))
+ for d2 in os.listdir(os.path.join(data_root, d1))
+ if os.path.isdir(os.path.join(data_root, d1, d2))
+ ]
+
+ return sorted(all_data_folders, key=lambda x: tuple(map(int, x.rsplit('/', 2)[-2:])))
\ No newline at end of file
diff --git a/phantom/phantom/processors/phantom_data.py b/phantom/phantom/processors/phantom_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd25a072505edc9f35e142a533619eb9e5d3e0f8
--- /dev/null
+++ b/phantom/phantom/processors/phantom_data.py
@@ -0,0 +1,340 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Callable, Any
+import numpy as np
+
+hand_side_dict = {
+ 'left': 0,
+ 'right': 1,
+}
+
+class LazyLoadingMixin:
+ """Mixin to provide lazy loading functionality for cached properties."""
+
+ def _invalidate_cache(self) -> None:
+ """Invalidate all cached properties. Override in subclasses."""
+ pass
+
+ def _get_cached_property(self, cache_attr: str, compute_func: Callable[[], Any]) -> Any:
+ """Generic lazy loading for cached properties."""
+ if getattr(self, cache_attr) is None:
+ setattr(self, cache_attr, compute_func())
+ return getattr(self, cache_attr)
+
+@dataclass
+class TrainingData:
+ """Container for processing results"""
+ frame_idx: int
+ valid: bool
+ action_pos_left: np.ndarray
+ action_orixyzw_left: np.ndarray
+ action_pos_right: np.ndarray
+ action_orixyzw_right: np.ndarray
+ action_gripper_left: np.ndarray
+ action_gripper_right: np.ndarray
+ gripper_width_left: np.ndarray
+ gripper_width_right: np.ndarray
+
+ @classmethod
+ def create_empty_frame(cls, frame_idx: int) -> 'TrainingData':
+ """Create a frame with no hand detection"""
+ return cls(
+ frame_idx=frame_idx,
+ valid=False,
+ action_pos_left=np.zeros((3,)),
+ action_orixyzw_left=np.zeros((4,)),
+ action_pos_right=np.zeros((3,)),
+ action_orixyzw_right=np.zeros((4,)),
+ action_gripper_left=0,
+ action_gripper_right=0,
+ gripper_width_left=0,
+ gripper_width_right=0,
+ )
+
+class TrainingDataSequence(LazyLoadingMixin):
+ """Container for a sequence of training data"""
+ def __init__(self):
+ self.frames: List[TrainingData] = []
+ self.metadata: Dict = {}
+
+ self._frame_indices: Optional[np.ndarray] = None
+ self._valid: Optional[np.ndarray] = None
+ self._action_pos_left: Optional[np.ndarray] = None
+ self._action_orixyzw_left: Optional[np.ndarray] = None
+ self._action_pos_right: Optional[np.ndarray] = None
+ self._action_orixyzw_right: Optional[np.ndarray] = None
+ self._action_gripper_left: Optional[np.ndarray] = None
+ self._action_gripper_right: Optional[np.ndarray] = None
+ self._gripper_width_left: Optional[np.ndarray] = None
+ self._gripper_width_right: Optional[np.ndarray] = None
+
+ def add_frame(self, frame: TrainingData) -> None:
+ """Add a frame to the sequence and invalidate cached properties."""
+ self.frames.append(frame)
+ self._invalidate_cache()
+
+ def save(self, path: str) -> None:
+ """Save the sequence to disk in both frame-wise and sequence-wise formats"""
+
+ sequence_data = {
+ 'frame_indices': self.frame_indices,
+ 'valid': self.valid,
+ 'action_pos_left': self.action_pos_left,
+ 'action_orixyzw_left': self.action_orixyzw_left,
+ 'action_pos_right': self.action_pos_right,
+ 'action_orixyzw_right': self.action_orixyzw_right,
+ 'action_gripper_left': self.action_gripper_left,
+ 'action_gripper_right': self.action_gripper_right,
+ 'gripper_width_left': self.gripper_width_left,
+ 'gripper_width_right': self.gripper_width_right,
+ }
+
+ np.savez_compressed(
+ path,
+ **sequence_data
+ )
+
+ @property
+ def frame_indices(self) -> np.ndarray:
+ """Lazy loading of all frame indices"""
+ return self._get_cached_property(
+ '_frame_indices',
+ lambda: np.arange(len(self.frames))
+ )
+
+ @property
+ def valid(self) -> np.ndarray:
+ """Lazy loading of all valid flags"""
+ return self._get_cached_property(
+ '_valid',
+ lambda: np.stack([f.valid for f in self.frames])
+ )
+
+ @property
+ def action_pos_left(self) -> np.ndarray:
+ """Lazy loading of all action positions"""
+ return self._get_cached_property(
+ '_action_pos_left',
+ lambda: np.stack([f.action_pos_left for f in self.frames])
+ )
+
+ @property
+ def action_orixyzw_left(self) -> np.ndarray:
+ """Lazy loading of all action orientations"""
+ return self._get_cached_property(
+ '_action_orixyzw_left',
+ lambda: np.stack([f.action_orixyzw_left for f in self.frames])
+ )
+
+ @property
+ def action_pos_right(self) -> np.ndarray:
+ """Lazy loading of all action positions"""
+ return self._get_cached_property(
+ '_action_pos_right',
+ lambda: np.stack([f.action_pos_right for f in self.frames])
+ )
+
+ @property
+ def action_orixyzw_right(self) -> np.ndarray:
+ """Lazy loading of all action orientations"""
+ return self._get_cached_property(
+ '_action_orixyzw_right',
+ lambda: np.stack([f.action_orixyzw_right for f in self.frames])
+ )
+
+ @property
+ def action_gripper_left(self) -> np.ndarray:
+ """Lazy loading of all action gripper distances"""
+ return self._get_cached_property(
+ '_action_gripper_left',
+ lambda: np.stack([f.action_gripper_left for f in self.frames])
+ )
+
+ @property
+ def action_gripper_right(self) -> np.ndarray:
+ """Lazy loading of all action gripper distances"""
+ return self._get_cached_property(
+ '_action_gripper_right',
+ lambda: np.stack([f.action_gripper_right for f in self.frames])
+ )
+
+ @property
+ def gripper_width_left(self) -> np.ndarray:
+ """Lazy loading of all gripper widths"""
+ return self._get_cached_property(
+ '_gripper_width_left',
+ lambda: np.stack([f.gripper_width_left for f in self.frames])
+ )
+
+ @property
+ def gripper_width_right(self) -> np.ndarray:
+ """Lazy loading of all gripper widths"""
+ return self._get_cached_property(
+ '_gripper_width_right',
+ lambda: np.stack([f.gripper_width_right for f in self.frames])
+ )
+
+ def _invalidate_cache(self):
+ """Invalidate all cached properties."""
+ self._frame_indices = None
+ self._valid = None
+ self._action_pos_left = None
+ self._action_orixyzw_left = None
+ self._action_pos_right = None
+ self._action_orixyzw_right = None
+ self._action_gripper_left = None
+ self._action_gripper_right = None
+ self._gripper_width_left = None
+ self._gripper_width_right = None
+
+ @classmethod
+ def load(cls, path: str) -> 'TrainingDataSequence':
+ """Load a sequence from disk"""
+ data = np.load(path, allow_pickle=True)
+ sequence = cls()
+
+ sequence._frame_indices = data['frame_indices']
+ sequence._valid = data['valid']
+ sequence._action_pos_left = data['action_pos_left']
+ sequence._action_orixyzw_left = data['action_orixyzw_left']
+ sequence._action_pos_right = data['action_pos_right']
+ sequence._action_orixyzw_right = data['action_orixyzw_right']
+ sequence._action_gripper_left = data['action_gripper_left']
+ sequence._action_gripper_right = data['action_gripper_right']
+ sequence._gripper_width_left = data['gripper_width_left']
+ sequence._gripper_width_right = data['gripper_width_right']
+
+ return sequence
+
+@dataclass
+class HandFrame:
+ """Data structure for a single frame of hand data"""
+ frame_idx: int
+ hand_detected: bool
+ img_rgb: np.ndarray
+ img_hamer: np.ndarray
+ kpts_2d: np.ndarray # shape: (N, 2)
+ kpts_3d: np.ndarray # shape: (N, 3)
+
+ @classmethod
+ def create_empty_frame(cls, frame_idx: int, img_rgb: np.ndarray) -> 'HandFrame':
+ """Create a frame with no hand detection"""
+ return cls(
+ frame_idx=frame_idx,
+ hand_detected=False,
+ img_rgb=img_rgb,
+ img_hamer=np.zeros_like(img_rgb),
+ kpts_2d=np.zeros((21, 2)),
+ kpts_3d=np.zeros((21, 3)),
+ )
+
+class HandSequence(LazyLoadingMixin):
+ """Container for a sequence of hand data"""
+ def __init__(self):
+ self.frames: List[HandFrame] = []
+ self.metadata: Dict = {}
+
+ self._frame_indices: Optional[np.ndarray] = None
+ self._hand_detected: Optional[np.ndarray] = None
+ self._img_rgb: Optional[np.ndarray] = None
+ self._img_hamer: Optional[np.ndarray] = None
+ self._kpts_2d: Optional[np.ndarray] = None
+ self._kpts_3d: Optional[np.ndarray] = None
+
+ def add_frame(self, frame: HandFrame) -> None:
+ """Add a frame to the sequence and invalidate cached properties."""
+ self.frames.append(frame)
+ self._invalidate_cache()
+
+ def get_frame(self, frame_idx: int) -> HandFrame:
+ """Get a frame by index."""
+ return self.frames[frame_idx]
+
+ def modify_frame(self, frame_idx: int, frame: HandFrame) -> None:
+ """Modify a frame at the given index and invalidate cached properties."""
+ self.frames[frame_idx] = frame
+ self._invalidate_cache()
+
+ def save(self, path: str) -> None:
+ """Save the sequence to disk in both frame-wise and sequence-wise formats"""
+ sequence_data = {
+ 'hand_detected': self.hand_detected,
+ 'kpts_2d': self.kpts_2d,
+ 'kpts_3d': self.kpts_3d,
+ 'frame_indices': self.frame_indices,
+ }
+
+ np.savez_compressed(
+ path,
+ **sequence_data
+ )
+
+ @property
+ def frame_indices(self) -> np.ndarray:
+ """Lazy loading of all frame indices"""
+ return self._get_cached_property(
+ '_frame_indices',
+ lambda: np.arange(len(self.frames))
+ )
+
+ @property
+ def hand_detected(self) -> np.ndarray:
+ """Lazy loading of all hand detection flags"""
+ return self._get_cached_property(
+ '_hand_detected',
+ lambda: np.stack([f.hand_detected for f in self.frames])
+ )
+
+ @property
+ def imgs_rgb(self) -> np.ndarray:
+ """Lazy loading of all RGB images"""
+ return self._get_cached_property(
+ '_img_rgb',
+ lambda: np.stack([f.img_rgb for f in self.frames])
+ )
+
+ @property
+ def imgs_hamer(self) -> np.ndarray:
+ """Lazy loading of all HAMER images"""
+ return self._get_cached_property(
+ '_img_hamer',
+ lambda: np.stack([f.img_hamer for f in self.frames])
+ )
+
+ @property
+ def kpts_2d(self) -> np.ndarray:
+ """Lazy loading of all 2D keypoints"""
+ return self._get_cached_property(
+ '_kpts_2d',
+ lambda: np.stack([f.kpts_2d for f in self.frames])
+ )
+
+ @property
+ def kpts_3d(self) -> np.ndarray:
+ """Lazy loading of all 3D keypoints"""
+ return self._get_cached_property(
+ '_kpts_3d',
+ lambda: np.stack([f.kpts_3d for f in self.frames])
+ )
+
+ @classmethod
+ def load(cls, path: str) -> 'HandSequence':
+ """Load a sequence from disk"""
+ data = np.load(path, allow_pickle=True)
+ sequence = cls()
+
+ # Load pre-computed sequence-wise data
+ sequence._frame_indices = data['frame_indices']
+ sequence._hand_detected = data['hand_detected']
+ sequence._kpts_2d = data['kpts_2d']
+ sequence._kpts_3d = data['kpts_3d']
+
+ return sequence
+
+ def _invalidate_cache(self):
+ """Invalidate all cached properties."""
+ self._frame_indices = None
+ self._hand_detected = None
+ self._img_rgb = None
+ self._img_hamer = None
+ self._kpts_2d = None
+ self._kpts_3d = None
\ No newline at end of file
diff --git a/phantom/phantom/processors/robotinpaint_processor.py b/phantom/phantom/processors/robotinpaint_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd8b1e53191d19dde50c21f150e79d3358573dd1
--- /dev/null
+++ b/phantom/phantom/processors/robotinpaint_processor.py
@@ -0,0 +1,785 @@
+"""
+Robot Inpainting Processor Module
+
+This module uses MuJoCo to render robot models and overlay them onto human demonstration videos.
+
+Processing Pipeline:
+1. Load smoothed robot trajectories from previous processing stages
+2. Initialize MuJoCo robot simulation with calibrated camera parameters
+3. For each frame:
+ - Move simulated robot to target pose from human demonstration
+ - Render robot from calibrated camera viewpoint
+ - Apply depth-based occlusion handling (Optional)
+ - Create robot overlay on human demonstration video
+4. Generate training data with robot state annotations
+5. Save robot-inpainted videos and training data
+"""
+
+import os
+import pdb
+import numpy as np
+import cv2
+from tqdm import tqdm
+import mediapy as media
+from scipy.spatial.transform import Rotation
+from typing import Tuple, Dict, List, Optional, Any, Union
+import logging
+from dataclasses import dataclass
+
+from phantom.processors.phantom_data import TrainingData, TrainingDataSequence, HandSequence
+from phantom.processors.base_processor import BaseProcessor
+from phantom.twin_bimanual_robot import TwinBimanualRobot, MujocoCameraParams
+from phantom.twin_robot import TwinRobot
+from phantom.processors.paths import Paths
+
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class RobotState:
+ """
+ Container for robot state data including pose and gripper configuration.
+
+ Attributes:
+ pos: 3D position coordinates in world frame
+ ori_xyzw: Quaternion orientation in XYZW format (scalar-last)
+ gripper_pos: Gripper opening distance or action value
+ """
+ pos: np.ndarray
+ ori_xyzw: np.ndarray
+ gripper_pos: float
+
+class RobotInpaintProcessor(BaseProcessor):
+ """
+ Uses mujoco to overlay robot on human inpainted images.
+ """
+ # Processing constants for quality control and output formatting
+ TRACKING_ERROR_THRESHOLD = 0.05 # Maximum tracking error in meters
+ DEFAULT_FPS = 15 # Standard frame rate for output videos
+ DEFAULT_CODEC = "ffv1" # Lossless codec for high-quality output
+
+ def __init__(self, args: Any) -> None:
+ """
+ Initialize the robot inpainting processor with simulation parameters.
+
+ Args:
+ args: Command line arguments containing robot configuration,
+ camera parameters, and processing options
+ """
+ super().__init__(args)
+ self.use_depth = self.depth_for_overlay
+ self._initialize_robot()
+
+ def _initialize_robot(self) -> None:
+ """
+ Initialize the twin robot simulation with calibrated camera parameters.
+ """
+ # Generate MuJoCo camera parameters from real-world calibration
+ camera_params = self._get_mujoco_camera_params()
+ img_w, img_h = self._get_image_dimensions()
+
+ # Initialize appropriate robot configuration
+ if self.bimanual_setup == "single_arm":
+ self.twin_robot = TwinRobot(
+ self.robot,
+ self.gripper,
+ camera_params,
+ camera_height=img_h,
+ camera_width=img_w,
+ render=self.render,
+ n_steps_short=3,
+ n_steps_long=75,
+ debug_cameras=self.debug_cameras,
+ square=self.square,
+ )
+ else:
+ self.twin_robot = TwinBimanualRobot(
+ self.robot,
+ self.gripper,
+ self.bimanual_setup,
+ camera_params,
+ camera_height=img_h,
+ camera_width=img_w,
+ render=self.render,
+ n_steps_short=10,
+ n_steps_long=75,
+ debug_cameras=self.debug_cameras,
+ epic=self.epic,
+ joint_controller=False, # Use operational-space control
+ )
+
+ def __del__(self):
+ """Clean up robot simulation resources."""
+ if hasattr(self, 'twin_robot'):
+ self.twin_robot.close()
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process a single demonstration to create robot-inpainted visualization.
+
+ Args:
+ data_sub_folder: Path to demonstration data folder containing
+ smoothed trajectories and original video data
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+ if self._should_skip_processing(save_folder):
+ return
+ paths = self.get_paths(save_folder)
+
+ # Reinitialize robot simulation for each demo to ensure clean state
+ self.__del__()
+ self._initialize_robot()
+
+ # Load and prepare demonstration data
+ data = self._load_data(paths)
+ images = self._load_images(paths, data["union_indices"])
+ gripper_actions, gripper_widths = self._process_gripper_widths(paths, data)
+
+ # Process all frames to generate robot overlays and training data
+ sequence, img_overlay, img_birdview = self._process_frames(images, data, gripper_actions, gripper_widths)
+
+ # Save comprehensive results
+ self._save_results(paths, sequence, img_overlay, img_birdview)
+
+ def _process_frames(self, images: Dict[str, np.ndarray], data: Dict[str, np.ndarray],
+ gripper_actions: Dict[str, np.ndarray], gripper_widths: Dict[str, np.ndarray]) -> Tuple[TrainingDataSequence, List[np.ndarray], Optional[List[np.ndarray]]]:
+ """
+ Process each frame to generate robot overlays and training data.
+
+ Args:
+ images: Dictionary containing human demonstration images and masks
+ data: Robot trajectory data (positions and orientations)
+ gripper_actions: Processed gripper action commands
+ gripper_widths: Gripper opening distances
+
+ Returns:
+ Tuple containing:
+ - TrainingDataSequence with robot state annotations
+ - List of robot overlay images
+ - Optional list of bird's eye view images (if debug cameras enabled)
+ """
+ sequence = TrainingDataSequence()
+ img_overlay = []
+ img_birdview = None
+ if "birdview" in self.debug_cameras:
+ img_birdview = []
+
+ for idx in tqdm(range(len(images['human_imgs'])), desc="Processing frames"):
+ # Extract robot states for current frame
+ left_state = self._get_robot_state(
+ data['ee_pts_left'][idx],
+ data['ee_oris_left'][idx],
+ gripper_widths['left'][idx]
+ )
+ right_state = self._get_robot_state(
+ data['ee_pts_right'][idx],
+ data['ee_oris_right'][idx],
+ gripper_widths['right'][idx]
+ )
+
+ # Process individual frame with robot simulation
+ frame_results = self._process_single_frame(
+ images, left_state, right_state, idx
+ )
+
+ # Handle failed processing (tracking errors, simulation issues)
+ if frame_results is None:
+ print(f"sdfsdfsTracking error too large at frame {idx}, skipping")
+ sequence.add_frame(TrainingData.create_empty_frame(
+ frame_idx=idx,
+ ))
+ img_overlay.append(np.zeros_like(images['human_imgs'][idx]))
+ if "birdview" in self.debug_cameras:
+ img_birdview.append(np.zeros_like(images['human_imgs'][idx]))
+ else:
+ # Create comprehensive training data annotation
+ sequence.add_frame(TrainingData(
+ frame_idx=idx,
+ valid=True,
+ action_pos_left=left_state.pos,
+ action_orixyzw_left=left_state.ori_xyzw,
+ action_pos_right=right_state.pos,
+ action_orixyzw_right=right_state.ori_xyzw,
+ action_gripper_left=gripper_actions['left'][idx],
+ action_gripper_right=gripper_actions['right'][idx],
+ gripper_width_left=gripper_widths['left'][idx],
+ gripper_width_right=gripper_widths['right'][idx],
+ ))
+ img_overlay.append(frame_results['rgb_robot_overlay'])
+ if "birdview" in self.debug_cameras:
+ img_birdview.append(frame_results['birdview_img'])
+ return sequence, img_overlay, img_birdview
+
+
+ def _process_single_frame(self, images: Dict[str, np.ndarray],
+ left_state: RobotState,
+ right_state: RobotState,
+ idx: int) -> Optional[Dict[str, np.ndarray]]:
+ """
+ Process a single frame to generate robot overlay and validate tracking.
+
+ Args:
+ images: Dictionary containing human images and segmentation data
+ left_state: Target state for left robot arm
+ right_state: Target state for right robot arm
+ idx: Frame index for initialization and logging
+
+ Returns:
+ Dictionary containing rendered robot overlay and debug camera views,
+ or None if tracking error exceeds threshold
+ """
+ # Prepare robot target state based on configuration
+ if self.bimanual_setup == "single_arm":
+ if self.target_hand == "left":
+ target_state = {
+ "pos": left_state.pos,
+ "ori_xyzw": left_state.ori_xyzw,
+ "gripper_pos": left_state.gripper_pos,
+ }
+ else:
+ target_state = {
+ "pos": right_state.pos,
+ "ori_xyzw": right_state.ori_xyzw,
+ "gripper_pos": right_state.gripper_pos,
+ }
+ else:
+ # Bimanual configuration requires coordinated control
+ target_state = {
+ "pos": [right_state.pos, left_state.pos],
+ "ori_xyzw": [right_state.ori_xyzw, left_state.ori_xyzw],
+ "gripper_pos": [right_state.gripper_pos, left_state.gripper_pos],
+ }
+
+ # Move robot to target state and get simulation results
+ robot_results = self.twin_robot.move_to_target_state(
+ target_state, init=(idx == 0) # Initialize on first frame
+ )
+
+ # Validate tracking accuracy to ensure quality
+ if self.bimanual_setup == "single_arm":
+ if robot_results['pos_err'] > self.TRACKING_ERROR_THRESHOLD:
+ print(f"Tracking error too large at frame {idx}, skipping", robot_results['pos_err'])
+ logger.warning(f"Tracking error too large at frame {idx}, skipping")
+ return None
+ else:
+ if robot_results['left_pos_err'] > self.TRACKING_ERROR_THRESHOLD or robot_results['right_pos_err'] > self.TRACKING_ERROR_THRESHOLD:
+ logger.warning(f"Tracking error too large at frame {idx}, skipping")
+ return None
+
+ # Generate robot overlay using appropriate method
+ if self.use_depth:
+ rgb_robot_overlay = self._process_robot_overlay_with_depth(
+ images['human_imgs'][idx],
+ images['human_masks'][idx],
+ images['imgs_depth'][idx],
+ robot_results
+ )
+ else:
+ rgb_robot_overlay = self._process_robot_overlay(
+ images['human_imgs'][idx], robot_results
+ )
+
+ # Prepare output with main overlay and debug camera views
+ output = {
+ 'rgb_robot_overlay': rgb_robot_overlay,
+ }
+
+ # Add debug camera views if requested
+ for cam in self.debug_cameras:
+ output[f"{cam}_img"] = (robot_results[f"{cam}_img"] * 255).astype(np.uint8)
+
+ return output
+
+ def _should_skip_processing(self, save_folder: str) -> bool:
+ """
+ Check if processing should be skipped due to existing output files.
+
+ Args:
+ save_folder: Directory where output files would be saved
+
+ Returns:
+ True if processing should be skipped, False otherwise
+ """
+ if self.skip_existing:
+ try:
+ with os.scandir(save_folder) as it:
+ existing_files = {entry.name for entry in it if entry.is_file()}
+ if str("video_overlay"+f"_{self.robot}_{self.bimanual_setup}.mkv") in existing_files:
+ print(f"Skipping existing demo {save_folder}")
+ return True
+ except FileNotFoundError:
+ return False
+ return False
+
+ def _load_data(self, paths: Paths) -> Dict[str, np.ndarray]:
+ """
+ Load robot trajectory data from smoothed action files.
+
+ Args:
+ paths: Paths object containing file locations
+
+ Returns:
+ Dictionary containing robot trajectory data and frame indices
+ """
+ if self.bimanual_setup == "single_arm":
+ # Get paths based on target hand for single-arm operation
+ smoothed_base = getattr(paths, f"smoothed_actions_{self.target_hand}")
+ actions_base = getattr(paths, f"actions_{self.target_hand}")
+ smoothed_actions_path = str(smoothed_base).replace(".npz", f"_{self.bimanual_setup}.npz")
+ actions_path = str(actions_base).replace(".npz", f"_{self.bimanual_setup}.npz")
+
+ # Load actual trajectory data for target hand
+ ee_pts = np.load(smoothed_actions_path)["ee_pts"]
+ ee_oris = np.load(smoothed_actions_path)["ee_oris"]
+
+ # Create dummy data for non-target hand
+ dummy_pts = np.zeros((len(ee_pts), 3))
+ dummy_oris = np.eye(3)[None, :, :].repeat(len(ee_oris), axis=0)
+
+ # Create data dictionary with target hand data and dummy data for other hand
+ other_hand = "right" if self.target_hand == "left" else "left"
+ return {
+ f'ee_pts_{self.target_hand}': ee_pts,
+ f'ee_oris_{self.target_hand}': ee_oris,
+ f'ee_pts_{other_hand}': dummy_pts,
+ f'ee_oris_{other_hand}': dummy_oris,
+ 'union_indices': np.load(actions_path, allow_pickle=True)["union_indices"]
+ }
+
+ # Load bimanual trajectory data
+ smoothed_actions_left_path = str(paths.smoothed_actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ smoothed_actions_right_path = str(paths.smoothed_actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ actions_left_path = str(paths.actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ return {
+ 'ee_pts_left': np.load(smoothed_actions_left_path)["ee_pts"],
+ 'ee_oris_left': np.load(smoothed_actions_left_path)["ee_oris"],
+ 'ee_pts_right': np.load(smoothed_actions_right_path)["ee_pts"],
+ 'ee_oris_right': np.load(smoothed_actions_right_path)["ee_oris"],
+ 'union_indices': np.load(actions_left_path, allow_pickle=True)["union_indices"]
+ }
+
+ def _load_images(self, paths: Paths, union_indices: np.ndarray) -> Dict[str, np.ndarray]:
+ """
+ Load and index human demonstration images and associated data.
+
+ Args:
+ paths: Paths object containing image file locations
+ union_indices: Frame indices to extract from full video sequences
+
+ Returns:
+ Dictionary containing indexed human images, masks, and depth data
+ """
+ return {
+ 'human_masks': np.load(paths.masks_arm)[union_indices],
+ 'human_imgs': np.array(media.read_video(paths.video_human_inpaint))[union_indices],
+ 'imgs_depth': np.load(paths.depth)[union_indices] if self.use_depth else None
+ }
+
+ def _process_gripper_widths(self, paths: Paths, data: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
+ """
+ Process gripper distance data into robot action commands.
+
+ Args:
+ paths: Paths object containing smoothed action file locations
+ data: Dictionary containing trajectory data and frame indices
+
+ Returns:
+ Tuple containing:
+ - Dictionary of gripper action commands for each hand
+ - Dictionary of gripper width values for each hand
+ """
+ if self.bimanual_setup == "single_arm":
+ # Get the appropriate smoothed actions path based on target hand
+ base_path = getattr(paths, f"smoothed_actions_{self.target_hand}")
+ smoothed_actions_path = str(base_path).replace(".npz", f"_{self.bimanual_setup}.npz")
+
+ # Compute gripper actions and widths from smoothed data
+ actions, widths = self._compute_gripper_actions(
+ np.load(smoothed_actions_path)["ee_widths"]
+ )
+
+ # Create return dictionaries with actions for target hand, zeros for the other
+ num_indices = len(data['union_indices'])
+ other_hand = "right" if self.target_hand == "left" else "left"
+
+ return (
+ {self.target_hand: actions, other_hand: np.zeros(num_indices)},
+ {self.target_hand: widths, other_hand: np.zeros(num_indices)}
+ )
+
+ # Process bimanual gripper data
+ smoothed_actions_left_path = str(paths.smoothed_actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ smoothed_actions_right_path = str(paths.smoothed_actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ left_actions, left_widths = self._compute_gripper_actions(
+ np.load(smoothed_actions_left_path)["ee_widths"]
+ )
+ right_actions, right_widths = self._compute_gripper_actions(
+ np.load(smoothed_actions_right_path)["ee_widths"]
+ )
+ return {'left': left_actions, 'right': right_actions}, {'left': left_widths, 'right': right_widths}
+
+
+ def _compute_gripper_actions(self, list_gripper_dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Convert continuous gripper distances to discrete robot gripper actions.
+ Args:
+ list_gripper_dist: Array of gripper distances throughout trajectory
+
+ Returns:
+ Tuple containing:
+ - Gripper action commands (0 for grasp, distance for open)
+ - Processed gripper width values
+ """
+ try:
+ # Analyze gripper distance range and determine grasp threshold
+ min_val, max_val = np.min(list_gripper_dist), np.max(list_gripper_dist)
+ thresh = min_val + 0.2 * (max_val - min_val) # 20% above minimum
+
+ # Classify gripper states: 0 = closed/grasping, 1 = open
+ gripper_state = np.array([0 if dist < thresh else 1 for dist in list_gripper_dist])
+
+ # Find range of grasping action
+ min_idx_pos = np.where(gripper_state == 0)[0][0]
+ max_idx_pos = np.where(gripper_state == 0)[0][-1]
+
+ # Generate gripper action commands
+ list_gripper_actions = []
+ for idx in range(len(list_gripper_dist)):
+ if min_idx_pos <= idx <= max_idx_pos:
+ # During grasping phase: use grasp command (0) and limit distance
+ list_gripper_actions.append(0)
+ list_gripper_dist[idx] = np.min([list_gripper_dist[idx], thresh])
+ else:
+ # Outside grasping phase: use distance as action command
+ list_gripper_actions.append(list_gripper_dist[idx])
+ except:
+ # Fallback: use distances directly if processing fails
+ list_gripper_actions = list_gripper_dist.tolist()
+
+ return np.array(list_gripper_actions), list_gripper_dist
+
+ def _get_robot_state(self, ee_pt: np.ndarray, ori_matrix: np.ndarray, gripper_dist: float) -> RobotState:
+ """
+ Convert trajectory data to robot state representation.
+
+ Args:
+ ee_pt: End-effector position in 3D space
+ ori_matrix: 3x3 rotation matrix for end-effector orientation
+ gripper_dist: Gripper opening distance
+
+ Returns:
+ RobotState object containing pose and gripper information
+ """
+ # Convert rotation matrix to quaternion (XYZW format for robot control)
+ ori_xyzw = Rotation.from_matrix(ori_matrix).as_quat(scalar_first=False)
+ robot_state = RobotState(pos=ee_pt, ori_xyzw=ori_xyzw, gripper_pos=gripper_dist)
+ return robot_state
+
+ def _process_robot_overlay(self, img: np.ndarray, robot_results: Dict[str, Any]) -> np.ndarray:
+ """
+ Create robot overlay on human image using segmentation masks.
+
+ Args:
+ img: Original human demonstration image
+ robot_results: Dictionary containing robot rendering results
+
+ Returns:
+ Image with robot overlay applied
+ """
+ # Extract robot rendering and segmentation data
+ rgb_img_sim = (robot_results['rgb_img'] * 255).astype(np.uint8)
+ H, W = rgb_img_sim.shape[:2]
+
+ # Resize robot rendering and masks to match output resolution
+ if self.square:
+ rgb_img_sim = cv2.resize(rgb_img_sim, (self.output_resolution, self.output_resolution))
+ robot_mask = cv2.resize(robot_results['robot_mask'], (self.output_resolution, self.output_resolution))
+ robot_mask[robot_mask > 0] = 1
+ gripper_mask = cv2.resize(robot_results['gripper_mask'], (self.output_resolution, self.output_resolution))
+ gripper_mask[gripper_mask > 0] = 1
+ else:
+ rgb_img_sim = cv2.resize(rgb_img_sim, (int(W/H*self.output_resolution), self.output_resolution))
+ robot_mask = cv2.resize(robot_results['robot_mask'], (int(W/H*self.output_resolution), self.output_resolution))
+ robot_mask[robot_mask > 0] = 1
+ gripper_mask = cv2.resize(robot_results['gripper_mask'], (int(W/H*self.output_resolution), self.output_resolution))
+ gripper_mask[gripper_mask > 0] = 1
+
+ # Create overlay by compositing robot over human image
+ img_robot_overlay = img.copy()
+ overlay_mask = (robot_mask == 1) | (gripper_mask == 1)
+ img_robot_overlay[overlay_mask] = rgb_img_sim[overlay_mask]
+
+ return img_robot_overlay
+
+ def _process_robot_overlay_with_depth(self, img: np.ndarray, hand_mask: np.ndarray,
+ img_depth: np.ndarray, robot_results: Dict[str, Any]) -> np.ndarray:
+ """
+ Create depth-aware robot overlay with realistic occlusion handling.
+
+ Args:
+ img: Original human demonstration image
+ hand_mask: Segmentation mask of human hand regions
+ img_depth: Depth image corresponding to the demonstration
+ robot_results: Dictionary containing robot rendering and depth results
+
+ Returns:
+ Image with depth-aware robot overlay applied
+ """
+ # Extract robot rendering and depth data
+ robot_mask = robot_results['robot_mask']
+ gripper_mask = robot_results['gripper_mask']
+ rgb_img_sim = robot_results['rgb_img']
+ depth_img_sim = np.squeeze(robot_results['depth_img'])
+ H, W = rgb_img_sim.shape[:2]
+
+ # Create masked depth images for occlusion analysis
+ depth_sim_masked = self._create_masked_depth(depth_img_sim, robot_mask, gripper_mask)
+ depth_masked = self._create_masked_depth(img_depth, robot_mask, gripper_mask)
+
+ # Process hand mask for improved occlusion handling
+ hand_mask = self._dilate_mask(hand_mask.astype(np.uint8))
+
+ # Create overlay mask using depth-based occlusion
+ img_robot_overlay = img.copy()
+ overlay_mask = self._create_overlay_mask(
+ robot_mask, gripper_mask, depth_masked, depth_sim_masked, hand_mask
+ )
+
+ # Convert and resize robot rendering
+ rgb_img_sim = (rgb_img_sim * 255).astype(np.uint8)
+
+ if self.square:
+ resize_shape = (self.output_resolution, self.output_resolution)
+ else:
+ resize_shape = (int(W/H*self.output_resolution), self.output_resolution)
+
+ # Apply final overlay with depth-aware occlusion
+ rgb_img_sim = cv2.resize(rgb_img_sim, resize_shape)
+ overlay_mask = cv2.resize(overlay_mask.astype(np.uint8), resize_shape)
+ overlay_mask[overlay_mask > 0] = 1
+ overlay_mask = overlay_mask.astype(bool)
+
+ img_robot_overlay[overlay_mask] = rgb_img_sim[overlay_mask]
+
+ return img_robot_overlay
+
+ def _create_masked_depth(self, depth_img: np.ndarray, robot_mask: np.ndarray,
+ gripper_mask: np.ndarray) -> np.ndarray:
+ """
+ Create depth image masked to robot regions for occlusion analysis.
+
+ Args:
+ depth_img: Input depth image
+ robot_mask: Binary mask indicating robot regions
+ gripper_mask: Binary mask indicating gripper regions
+
+ Returns:
+ Depth image with values only in robot/gripper regions
+ """
+ masked_img = np.zeros_like(depth_img)
+ mask = (robot_mask == 1) | (gripper_mask == 1)
+ masked_img[mask] = depth_img[mask]
+ return masked_img
+
+ def _dilate_mask(self, mask: np.ndarray) -> np.ndarray:
+ """
+ Apply morphological dilation to expand mask boundaries.
+
+ Args:
+ mask: Binary mask to dilate
+
+ Returns:
+ Dilated binary mask
+ """
+ kernel = np.ones((5, 5), np.uint8)
+ return cv2.dilate(mask, kernel, iterations=1)
+
+ def _create_overlay_mask(self, robot_mask: np.ndarray, gripper_mask: np.ndarray,
+ depth_masked: np.ndarray, depth_sim_masked: np.ndarray,
+ hand_mask: np.ndarray) -> np.ndarray:
+ """
+ Create sophisticated overlay mask using depth-based occlusion reasoning.
+
+ Args:
+ robot_mask: Binary mask for robot body regions
+ gripper_mask: Binary mask for robot gripper regions
+ depth_masked: Real depth image masked to robot regions
+ depth_sim_masked: Simulated robot depth masked to robot regions
+ hand_mask: Binary mask for human hand regions
+
+ Returns:
+ Binary mask indicating where robot overlay should be applied
+ """
+ # Start with basic robot visibility mask
+ overlay_mask = (robot_mask == 1) | (gripper_mask == 1)
+
+ # Apply depth-based occlusion: hide robot when it's behind real objects
+ # and not in hand regions (where occlusion handling is more complex)
+ overlay_mask[(depth_masked < depth_sim_masked) & (hand_mask == 0)] = 0
+
+ return overlay_mask
+
+ def _save_results(self, paths: Paths, sequence: TrainingDataSequence, img_overlay: List[np.ndarray],
+ img_birdview: Optional[List[np.ndarray]] = None) -> None:
+ """
+ Save comprehensive robot inpainting results to disk.
+
+ Args:
+ paths: Paths object containing output file locations
+ sequence: Training data sequence with robot state annotations
+ img_overlay: List of robot overlay images
+ img_birdview: Optional list of bird's eye view images for analysis
+ """
+ # Create output directory
+ os.makedirs(paths.inpaint_processor, exist_ok=True)
+
+ if len(img_overlay) == 0:
+ print("No robot inpainted images, skipping")
+ return
+
+ # Save main robot-inpainted video
+ video_path = str(paths.video_overlay).split(".mkv")[0] + f"_{self.robot}_{self.bimanual_setup}.mkv"
+ self._save_video(video_path, img_overlay)
+
+ # Save bird's eye view video for analysis and debugging
+ if img_birdview is not None:
+ birdview_path = str(paths.video_birdview).split(".mkv")[0] + f"_{self.robot}_{self.bimanual_setup}.mkv"
+ self._save_video(birdview_path, np.array(img_birdview))
+
+ # Save comprehensive training data with robot state annotations
+ training_data_path = str(paths.training_data).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ sequence.save(training_data_path)
+
+ def _save_video(self, path: str, frames: List[np.ndarray]) -> None:
+ """
+ Save video with consistent encoding parameters.
+
+ Args:
+ path: Output video file path
+ frames: List of video frames to save
+ """
+ media.write_video(
+ path,
+ frames,
+ fps=self.DEFAULT_FPS,
+ codec=self.DEFAULT_CODEC
+ )
+
+ def _get_mujoco_camera_params(self) -> MujocoCameraParams:
+ """
+ Generate MuJoCo camera parameters from real-world camera calibration.
+
+ Returns:
+ MujocoCameraParams object with calibrated camera settings
+ """
+ # Extract real-world camera extrinsics and convert to MuJoCo format
+ extrinsics = self.extrinsics[0]
+ camera_ori_wxyz = self._convert_real_camera_ori_to_mujoco(
+ np.array(extrinsics["camera_base_ori"])
+ )
+
+ # Calculate image dimensions and camera intrinsics
+ img_w, img_h = self._get_image_dimensions()
+ offset = self._calculate_image_offset(img_w, img_h)
+ fx, fy, cx, cy = self._get_camera_intrinsics(offset)
+ sensor_width, sensor_height = self._calculate_sensor_size(img_w, img_h, fx, fy)
+
+ # Select appropriate camera name based on dataset
+ if self.epic:
+ camera_name = "zed"
+ else:
+ camera_name = "frontview"
+
+ return MujocoCameraParams(
+ name=camera_name,
+ pos=extrinsics["camera_base_pos"],
+ ori_wxyz=camera_ori_wxyz,
+ fov=self.intrinsics_dict["v_fov"],
+ resolution=(img_h, img_w),
+ sensorsize=np.array([sensor_width, sensor_height]),
+ principalpixel=np.array([img_w/2-cx, cy-img_h/2]),
+ focalpixel=np.array([fx, fy])
+ )
+
+ def _get_image_dimensions(self) -> Tuple[int, int]:
+ """
+ Calculate image dimensions based on input resolution configuration.
+
+ Returns:
+ Tuple of (width, height) in pixels
+ """
+ # Epic
+ if self.input_resolution == 256:
+ img_w = 456
+ # Phantom paper
+ elif self.input_resolution == 1080:
+ img_w = self.input_resolution * 16 // 9
+ img_h = self.input_resolution
+ return img_w, img_h
+
+ def _calculate_image_offset(self, img_w: int, img_h: int) -> int:
+ """
+ Calculate horizontal image offset for square aspect ratio processing.
+
+ Args:
+ img_w: Image width in pixels
+ img_h: Image height in pixels
+
+ Returns:
+ Horizontal offset in pixels
+ """
+ if self.square:
+ offset = (img_w - img_h) // 2
+ else:
+ offset = 0
+ return offset
+
+ def _get_camera_intrinsics(self, offset: int) -> Tuple[float, float, float, float]:
+ """
+ Extract camera intrinsic parameters with offset correction.
+
+ Args:
+ offset: Horizontal offset for principal point adjustment
+
+ Returns:
+ Tuple of (fx, fy, cx, cy) camera intrinsic parameters
+ """
+ return self.intrinsics_dict["fx"], self.intrinsics_dict["fy"], self.intrinsics_dict["cx"]+offset, self.intrinsics_dict["cy"]
+
+ def _calculate_sensor_size(self, img_w: int, img_h: int, fx: float, fy: float) -> Tuple[float, float]:
+ """
+ Calculate physical sensor dimensions from image resolution and focal length.
+
+ Args:
+ img_w: Image width in pixels
+ img_h: Image height in pixels
+ fx: Focal length in x direction (pixels)
+ fy: Focal length in y direction (pixels)
+
+ Returns:
+ Tuple of (sensor_width, sensor_height) in meters
+ """
+ sensor_width = img_w / fy / 1000
+ sensor_height = img_h / fx / 1000
+ return sensor_width, sensor_height
+
+ @staticmethod
+ def _convert_real_camera_ori_to_mujoco(camera_ori_matrix: np.ndarray) -> np.ndarray:
+ """
+ Convert real-world camera orientation to MuJoCo coordinate system.
+
+ Args:
+ camera_ori_matrix: 3x3 rotation matrix in real-world coordinates
+
+ Returns:
+ Quaternion in WXYZ format for MuJoCo
+ """
+ # Apply coordinate system transformation (flip Y and Z axes)
+ camera_ori_matrix[:, [1, 2]] = -camera_ori_matrix[:, [1, 2]]
+
+ # Convert to quaternion in MuJoCo's WXYZ format
+ r = Rotation.from_matrix(camera_ori_matrix)
+ camera_ori_wxyz = r.as_quat(scalar_first=True)
+ return camera_ori_wxyz
+
+
diff --git a/phantom/phantom/processors/segmentation_processor.py b/phantom/phantom/processors/segmentation_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9adc48f70e4509f9e714f42be59b442405ecd33b
--- /dev/null
+++ b/phantom/phantom/processors/segmentation_processor.py
@@ -0,0 +1,1056 @@
+"""
+Segmentation Processor Module
+
+This module uses SAM2 to create masks of hands and arms in video sequences.
+
+Processing Pipeline:
+1. Load video frames and detection/pose data from previous stages
+2. Initialize segmentation with highest-quality detection frame
+3. Propagate segmentation bidirectionally (forward and reverse)
+4. Combine temporal results for complete sequence coverage
+5. Generate visualization videos and save segmentation masks
+
+The module supports different segmentation modes:
+- HandSegmentationProcessor: Precise hand-only segmentation
+- ArmSegmentationProcessor: Combined hand + arm segmentation
+"""
+
+import os
+import logging
+import shutil
+from tqdm import tqdm
+import numpy as np
+import mediapy as media
+import argparse
+from typing import Dict, Tuple, Optional, List
+
+from phantom.processors.paths import Paths
+from phantom.processors.base_processor import BaseProcessor
+from phantom.detectors.detector_sam2 import DetectorSam2
+from phantom.detectors.detector_detectron2 import DetectorDetectron2
+from phantom.utils.bbox_utils import get_overlap_score
+from phantom.processors.phantom_data import HandSequence
+
+logger = logging.getLogger(__name__)
+
+# Configuration constants for segmentation processing
+DEFAULT_FPS = 10
+DEFAULT_OVERLAP_THRESHOLD = 0.5
+DEFAULT_CODEC = "ffv1"
+ANNOTATION_CODEC = "h264"
+
+class BaseSegmentationProcessor(BaseProcessor):
+ """
+ Base class for video segmentation processing using SAM2.
+
+ The base processor establishes the framework for temporal segmentation processing,
+ where segmentation masks are propagated both forward and backward through time
+ to ensure temporal consistency and complete coverage of the video sequence.
+
+ Attributes:
+ detector_sam (DetectorSam2): SAM2 segmentation model instance
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize the base segmentation processor.
+
+ Args:
+ args: Command line arguments containing segmentation configuration
+ """
+ super().__init__(args)
+ self.detector_sam = DetectorSam2()
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process a single demonstration - to be implemented by subclasses.
+
+ Args:
+ data_sub_folder: Path to demonstration data folder
+
+ Raises:
+ NotImplementedError: Must be implemented by concrete subclasses
+ """
+ raise NotImplementedError("Subclasses must implement this method")
+
+ def _load_hamer_data(self, paths: Paths) -> Dict[str, HandSequence]:
+ """
+ Load hand pose estimation data from previous processing stage.
+
+ Args:
+ paths: Paths object containing file locations
+
+ Returns:
+ Dictionary containing left and right hand sequences
+ """
+ if self.bimanual_setup == "single_arm":
+ if self.target_hand == "left":
+ return {"left": HandSequence.load(paths.hand_data_left)}
+ elif self.target_hand == "right":
+ return {"right": HandSequence.load(paths.hand_data_right)}
+ else:
+ raise ValueError(f"Invalid target hand: {self.target_hand}")
+ elif self.bimanual_setup == "shoulders":
+ return {
+ "left": HandSequence.load(paths.hand_data_left),
+ "right": HandSequence.load(paths.hand_data_right)
+ }
+ else:
+ raise ValueError(f"Invalid bimanual setup: {self.bimanual_setup}")
+
+ @staticmethod
+ def _load_video(video_path: str) -> np.ndarray:
+ """
+ Load and validate video frames from disk.
+
+ Args:
+ video_path: Path to video file
+
+ Returns:
+ Array of RGB video frames
+
+ Raises:
+ FileNotFoundError: If video file doesn't exist
+ ValueError: If video file is empty or corrupted
+ """
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"Video file not found: {video_path}")
+
+ imgs_rgb = media.read_video(video_path)
+ if len(imgs_rgb) == 0:
+ raise ValueError("Empty video file")
+
+ return imgs_rgb
+
+ @staticmethod
+ def _load_bbox_data(bbox_path: str) -> Dict[str, np.ndarray]:
+ """
+ Load and validate bounding box detection data.
+
+ Args:
+ bbox_path: Path to bounding box data file
+
+ Returns:
+ Dictionary containing detection results from bounding box processor
+
+ Raises:
+ FileNotFoundError: If bounding box data file doesn't exist
+ """
+ if not os.path.exists(bbox_path):
+ raise FileNotFoundError(f"Bbox data not found: {bbox_path}")
+
+ return np.load(bbox_path)
+
+ @staticmethod
+ def _combine_sam_images(
+ imgs_rgb: np.ndarray,
+ imgs_forward: Dict[int, np.ndarray],
+ imgs_reverse: Dict[int, np.ndarray]
+ ) -> np.ndarray:
+ """
+ Combine forward and reverse SAM visualization images.
+
+ This method merges the visualization results from bidirectional
+ processing to create a complete visualization sequence.
+
+ Args:
+ imgs_rgb: Original RGB frames for shape reference
+ imgs_forward: Forward propagation visualization results
+ imgs_reverse: Reverse propagation visualization results
+
+ Returns:
+ Combined visualization array
+ """
+ result = np.zeros_like(imgs_rgb)
+ # Fill in forward propagation results
+ for idx in imgs_forward:
+ result[idx] = imgs_forward[idx]
+ # Fill in reverse propagation results (may overwrite forward results)
+ for idx in imgs_reverse:
+ result[idx] = imgs_reverse[idx]
+ return result
+
+ @staticmethod
+ def _combine_masks(
+ imgs_rgb: np.ndarray,
+ masks_forward: Dict[int, np.ndarray],
+ masks_reverse: Dict[int, np.ndarray]
+ ) -> np.ndarray:
+ """
+ Combine forward and reverse segmentation masks.
+
+ This method merges segmentation masks from bidirectional processing
+ to ensure complete temporal coverage of the video sequence.
+
+ Args:
+ imgs_rgb: Original RGB frames for shape reference
+ masks_forward: Forward propagation mask results
+ masks_reverse: Reverse propagation mask results
+
+ Returns:
+ Combined mask array with shape (num_frames, height, width)
+ """
+ result = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
+ for idx in masks_forward:
+ result[idx] = masks_forward[idx][0]
+ for idx in masks_reverse:
+ result[idx] = masks_reverse[idx][0]
+ return result
+
+class ArmSegmentationProcessor(BaseSegmentationProcessor):
+ """
+ Processor for segmenting combined hand and arm regions in video sequences.
+
+ Attributes:
+ detectron_detector (DetectorDetectron2): Detectron2 model for initial detection
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize the arm segmentation processor with detection models.
+
+ Args:
+ args: Command line arguments containing model configuration
+ """
+ super().__init__(args)
+
+ # Initialize Detectron2 for initial hand/arm detection
+ root_dir = "../submodules/phantom-hamer/"
+ self.detectron_detector = DetectorDetectron2(root_dir)
+
+
+ def process_one_demo(self, data_sub_folder: str, hamer_data: Optional[Dict[str, HandSequence]] = None) -> None:
+ """
+ Process a single video demonstration to generate combined hand + arm segmentation masks.
+
+ Args:
+ data_sub_folder: Path to the subfolder containing the demo data
+ hamer_data: Optional pre-loaded hand pose data for segmentation guidance
+
+ Raises:
+ FileNotFoundError: If required input files are not found
+ ValueError: If video frames or bounding boxes are invalid
+ """
+ # Setup and load all required data
+ save_folder, paths, imgs_rgb, bbox_data, det_bbox_data, hamer_data = self._setup_processing(
+ data_sub_folder, hamer_data
+ )
+
+ # Process based on setup type
+ if self.bimanual_setup == "single_arm":
+ masks = self._process_single_arm(imgs_rgb, bbox_data, det_bbox_data, hamer_data, paths)
+ elif self.bimanual_setup == "shoulders":
+ masks = self._process_bimanual(imgs_rgb, bbox_data, det_bbox_data, hamer_data, paths)
+ else:
+ raise ValueError(f"Invalid bimanual setup: {self.bimanual_setup}")
+
+ # Create visualization and save results
+ sam_imgs = self._create_visualization(imgs_rgb, masks)
+ self._validate_output_consistency(imgs_rgb, masks, sam_imgs)
+ self._save_results(paths, masks, sam_imgs)
+
+ def _setup_processing(
+ self,
+ data_sub_folder: str,
+ hamer_data: Optional[Dict[str, HandSequence]]
+ ) -> Tuple[str, Paths, np.ndarray, Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, HandSequence]]:
+ """
+ Setup processing environment and load all required data.
+
+ Args:
+ data_sub_folder: Path to the subfolder containing the demo data
+ hamer_data: Optional pre-loaded hand pose data
+
+ Returns:
+ Tuple containing: (save_folder, paths, imgs_rgb, bbox_data, det_bbox_data, hamer_data)
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+ paths = self.get_paths(save_folder)
+ paths._setup_original_images()
+ paths._setup_original_images_reverse()
+
+ # Load and validate all input data
+ imgs_rgb = self._load_video(paths.video_left)
+ bbox_data = self._load_bbox_data(paths.bbox_data)
+ det_bbox_data = self.get_detectron_bboxes(imgs_rgb, bbox_data)
+ if hamer_data is None:
+ hamer_data = self._load_hamer_data(paths)
+
+ return save_folder, paths, imgs_rgb, bbox_data, det_bbox_data, hamer_data
+
+ def _process_single_arm(
+ self,
+ imgs_rgb: np.ndarray,
+ bbox_data: Dict[str, np.ndarray],
+ det_bbox_data: Dict[str, np.ndarray],
+ hamer_data: Dict[str, HandSequence],
+ paths: Paths
+ ) -> np.ndarray:
+ """
+ Process single arm setup (left or right hand only).
+
+ Args:
+ imgs_rgb: RGB video frames
+ bbox_data: Bounding box detection data
+ det_bbox_data: Detectron2 refined bounding boxes
+ hamer_data: Hand pose estimation data
+ paths: Paths object for file management
+
+ Returns:
+ Boolean segmentation masks
+ """
+ if self.target_hand == "left":
+ hand_data = self._process_hand_data(
+ imgs_rgb,
+ bbox_data["left_bboxes"],
+ bbox_data["left_bbox_min_dist_to_edge"],
+ bbox_data["left_hand_detected"],
+ det_bbox_data["left_det_bboxes"],
+ hamer_data["left"],
+ paths,
+ "left"
+ )
+ masks = hand_data["left_masks"].astype(np.bool_)
+ elif self.target_hand == "right":
+ hand_data = self._process_hand_data(
+ imgs_rgb,
+ bbox_data["right_bboxes"],
+ bbox_data["right_bbox_min_dist_to_edge"],
+ bbox_data["right_hand_detected"],
+ det_bbox_data["right_det_bboxes"],
+ hamer_data["right"],
+ paths,
+ "right"
+ )
+ masks = hand_data["right_masks"].astype(np.bool_)
+ else:
+ raise ValueError(f"Invalid target hand: {self.target_hand}")
+
+ return masks.astype(np.bool_)
+
+ def _process_bimanual(
+ self,
+ imgs_rgb: np.ndarray,
+ bbox_data: Dict[str, np.ndarray],
+ det_bbox_data: Dict[str, np.ndarray],
+ hamer_data: Dict[str, HandSequence],
+ paths: Paths
+ ) -> np.ndarray:
+ """
+ Process bimanual setup (both hands combined).
+
+ Args:
+ imgs_rgb: RGB video frames
+ bbox_data: Bounding box detection data
+ det_bbox_data: Detectron2 refined bounding boxes
+ hamer_data: Hand pose estimation data
+ paths: Paths object for file management
+
+ Returns:
+ Combined boolean segmentation masks
+ """
+ # Process left hand with arm segmentation
+ left_data = self._process_hand_data(
+ imgs_rgb,
+ bbox_data["left_bboxes"],
+ bbox_data["left_bbox_min_dist_to_edge"],
+ bbox_data["left_hand_detected"],
+ det_bbox_data["left_det_bboxes"],
+ hamer_data["left"],
+ paths,
+ "left"
+ )
+
+ # Process right hand with arm segmentation
+ right_data = self._process_hand_data(
+ imgs_rgb,
+ bbox_data["right_bboxes"],
+ bbox_data["right_bbox_min_dist_to_edge"],
+ bbox_data["right_hand_detected"],
+ det_bbox_data["right_det_bboxes"],
+ hamer_data["right"],
+ paths,
+ "right"
+ )
+
+ # Convert to boolean masks and combine
+ left_masks = left_data["left_masks"].astype(np.bool_)
+ right_masks = right_data["right_masks"].astype(np.bool_)
+
+ # Generate combined video masks by taking the union of left and right masks
+ masks = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
+ for idx in range(len(imgs_rgb)):
+ masks[idx] = left_masks[idx] | right_masks[idx]
+
+ return masks.astype(np.bool_)
+
+ def _create_visualization(self, imgs_rgb: np.ndarray, masks: np.ndarray) -> np.ndarray:
+ """
+ Create visualization by masking out segmented regions.
+
+ Args:
+ imgs_rgb: Original RGB video frames
+ masks: Boolean segmentation masks
+
+ Returns:
+ Visualization images with masked regions set to black
+ """
+ sam_imgs = []
+ for idx in range(len(imgs_rgb)):
+ img = imgs_rgb[idx].copy() # Create copy to avoid modifying original
+ mask = masks[idx]
+ img[mask] = 0 # Set masked regions to black
+ sam_imgs.append(img)
+ return np.array(sam_imgs)
+
+ def _validate_output_consistency(
+ self,
+ imgs_rgb: np.ndarray,
+ masks: np.ndarray,
+ sam_imgs: np.ndarray
+ ) -> None:
+ """
+ Validate that output arrays have consistent dimensions.
+
+ Args:
+ imgs_rgb: Original RGB video frames
+ masks: Segmentation masks
+ sam_imgs: Visualization images
+
+ Raises:
+ AssertionError: If dimensions don't match
+ """
+ assert len(sam_imgs) == len(imgs_rgb), "Visualization length doesn't match input"
+ assert len(masks) == len(imgs_rgb), "Masks length doesn't match input"
+
+
+ def _process_hand_data(
+ self,
+ imgs_rgb: np.ndarray,
+ bboxes: np.ndarray,
+ bbox_min_dist: np.ndarray,
+ hand_detected: np.ndarray,
+ det_bboxes: np.ndarray,
+ hamer_data: HandSequence,
+ paths: Paths,
+ hand_side: str
+ ) -> Dict[str, np.ndarray]:
+ """
+ Process segmentation data for a single hand (left or right) with arm inclusion.
+
+ Args:
+ imgs_rgb: RGB video frames
+ bboxes: Hand bounding boxes from detection stage
+ bbox_min_dist: Minimum distances to image edges (quality metric)
+ hand_detected: Boolean flags indicating valid hand detections
+ det_bboxes: Refined bounding boxes from Detectron2
+ hamer_data: Hand pose data for segmentation guidance
+ paths: Paths object for file management
+ hand_side: "left" or "right" specifying which hand to process
+
+ Returns:
+ Dictionary containing segmentation masks and visualization images
+ """
+ # Handle cases with no valid detections
+ if not hand_detected.any() or max(bbox_min_dist) == 0:
+ return {
+ f"{hand_side}_masks": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1])),
+ f"{hand_side}_sam_imgs": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1], 3))
+ }
+
+ # Extract hand pose keypoints for segmentation guidance
+ kpts_2d = hamer_data.kpts_2d
+
+ # Find the frame with highest quality (furthest from edges)
+ max_dist_idx = np.argmax(bbox_min_dist)
+ points = np.expand_dims(kpts_2d[max_dist_idx], axis=1)
+ bbox_dets = det_bboxes[max_dist_idx]
+
+ # Use original bounding box if Detectron2 detection failed
+ if bbox_dets.sum() == 0:
+ bbox_dets = bboxes[max_dist_idx]
+
+ # Process segmentation in both temporal directions
+ masks_forward, sam_imgs_forward = self._run_sam_segmentation(
+ paths, bbox_dets, points, max_dist_idx, reverse=False
+ )
+ masks_reverse, sam_imgs_reverse = self._run_sam_segmentation(
+ paths, bbox_dets, points, max_dist_idx, reverse=True
+ )
+
+ # Combine bidirectional results
+ sam_imgs = self._combine_sam_images(imgs_rgb, sam_imgs_forward, sam_imgs_reverse)
+ masks = self._combine_masks(imgs_rgb, masks_forward, masks_reverse)
+
+ return {
+ f"{hand_side}_masks": masks,
+ f"{hand_side}_sam_imgs": sam_imgs
+ }
+
+ def _run_sam_segmentation(
+ self,
+ paths: Paths,
+ bbox_dets: np.ndarray,
+ points: np.ndarray,
+ max_dist_idx: int,
+ reverse: bool
+ ) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
+ """
+ Process video segmentation in either forward or reverse temporal direction.
+
+ Args:
+ paths: Paths object for file management
+ bbox_dets: Detectron2 bounding box for initialization
+ points: Hand keypoints for segmentation guidance
+ max_dist_idx: Index of highest-quality frame for initialization
+ reverse: Whether to process in reverse temporal order
+
+ Returns:
+ Tuple of (segmentation_masks, visualization_images)
+ """
+ return self.detector_sam.segment_video(
+ paths.original_images_folder,
+ bbox_dets,
+ points,
+ [max_dist_idx],
+ reverse=reverse
+ )
+
+ def get_detectron_bboxes(self, imgs_rgb: np.ndarray, bbox_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ """
+ Generate enhanced bounding boxes using Detectron2 for improved segmentation.
+
+ Args:
+ imgs_rgb: Array of RGB frames with shape (N, H, W, 3)
+ bbox_data: Initial bounding box data from hand detection stage containing:
+ - left_bboxes: Left hand bounding boxes
+ - right_bboxes: Right hand bounding boxes
+ - left_hand_detected: Boolean flags for left hand detection
+ - right_hand_detected: Boolean flags for right hand detection
+ - left_bbox_min_dist_to_edge: Quality metrics for left hand
+ - right_bbox_min_dist_to_edge: Quality metrics for right hand
+
+ Returns:
+ Dictionary containing refined bounding boxes:
+ - left_det_bboxes: Enhanced left hand bounding boxes
+ - right_det_bboxes: Enhanced right hand bounding boxes
+
+ Raises:
+ ValueError: If input array is empty or has incorrect shape
+ """
+ self._validate_detectron_input(imgs_rgb)
+
+ # Extract detection data and initialize output arrays
+ detection_data = self._extract_detection_data(bbox_data)
+ left_det_bboxes, right_det_bboxes = self._initialize_bbox_arrays(imgs_rgb)
+
+ # Process only highest-quality frames for efficiency
+ idx_list = self._get_quality_frame_indices(bbox_data)
+
+ for idx in tqdm(idx_list, desc="Processing frames"):
+ try:
+ self._process_detectron_frame(
+ idx, imgs_rgb, detection_data, left_det_bboxes, right_det_bboxes
+ )
+ except Exception as e:
+ logging.error(f"Error processing frame {idx}: {str(e)}")
+
+ return {"left_det_bboxes": left_det_bboxes, "right_det_bboxes": right_det_bboxes}
+
+ def _validate_detectron_input(self, imgs_rgb: np.ndarray) -> None:
+ """
+ Validate input array for Detectron2 processing.
+
+ Args:
+ imgs_rgb: Array of RGB frames
+
+ Raises:
+ ValueError: If input array is empty or has incorrect shape
+ """
+ if len(imgs_rgb) == 0:
+ raise ValueError("Empty input array - no video frames provided")
+
+ if len(imgs_rgb.shape) != 4 or imgs_rgb.shape[-1] != 3:
+ raise ValueError(f"Expected input shape (N, H, W, 3), got {imgs_rgb.shape}. "
+ f"Input should be RGB video frames.")
+
+ def _extract_detection_data(self, bbox_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ """
+ Extract detection data from bounding box data.
+
+ Args:
+ bbox_data: Bounding box detection data
+
+ Returns:
+ Dictionary containing extracted detection data
+ """
+ return {
+ "left_bboxes": bbox_data["left_bboxes"],
+ "right_bboxes": bbox_data["right_bboxes"],
+ "left_hand_detected": bbox_data["left_hand_detected"],
+ "right_hand_detected": bbox_data["right_hand_detected"]
+ }
+
+ def _initialize_bbox_arrays(self, imgs_rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Initialize output bounding box arrays.
+
+ Args:
+ imgs_rgb: RGB video frames for shape reference
+
+ Returns:
+ Tuple of (left_det_bboxes, right_det_bboxes) initialized arrays
+ """
+ left_det_bboxes = np.zeros((len(imgs_rgb), 4))
+ right_det_bboxes = np.zeros((len(imgs_rgb), 4))
+ return left_det_bboxes, right_det_bboxes
+
+ def _get_quality_frame_indices(self, bbox_data: Dict[str, np.ndarray]) -> List[int]:
+ """
+ Get indices of highest-quality frames for processing.
+
+ Args:
+ bbox_data: Bounding box detection data
+
+ Returns:
+ List of frame indices to process
+ """
+ idx_left = np.argmax(bbox_data["left_bbox_min_dist_to_edge"])
+ idx_right = np.argmax(bbox_data["right_bbox_min_dist_to_edge"])
+ return [idx_left, idx_right]
+
+ def _process_detectron_frame(
+ self,
+ idx: int,
+ imgs_rgb: np.ndarray,
+ detection_data: Dict[str, np.ndarray],
+ left_det_bboxes: np.ndarray,
+ right_det_bboxes: np.ndarray
+ ) -> None:
+ """
+ Process a single frame with Detectron2 detection.
+
+ Args:
+ idx: Frame index to process
+ imgs_rgb: RGB video frames
+ detection_data: Extracted detection data
+ left_det_bboxes: Left hand bounding box output array
+ right_det_bboxes: Right hand bounding box output array
+ """
+ left_hand_detected = detection_data["left_hand_detected"]
+ right_hand_detected = detection_data["right_hand_detected"]
+
+ # Skip frames without any hand detections
+ if not left_hand_detected[idx] and not right_hand_detected[idx]:
+ left_det_bboxes[idx] = np.array([0, 0, 0, 0])
+ right_det_bboxes[idx] = np.array([0, 0, 0, 0])
+ return
+
+ # Apply Detectron2 detection
+ img = imgs_rgb[idx]
+ det_bboxes, det_scores = self.detectron_detector.get_bboxes(img, visualize=False)
+
+ if len(det_bboxes) == 0:
+ return
+
+ # Match left hand detection with Detectron2 results
+ if left_hand_detected[idx]:
+ self._match_hand_detection(
+ idx, "left", detection_data, det_bboxes, left_det_bboxes
+ )
+
+ # Match right hand detection with Detectron2 results
+ if right_hand_detected[idx]:
+ self._match_hand_detection(
+ idx, "right", detection_data, det_bboxes, right_det_bboxes
+ )
+
+ def _match_hand_detection(
+ self,
+ idx: int,
+ hand_side: str,
+ detection_data: Dict[str, np.ndarray],
+ det_bboxes: np.ndarray,
+ output_bboxes: np.ndarray
+ ) -> None:
+ """
+ Match hand detection with Detectron2 results using overlap scores.
+
+ Args:
+ idx: Frame index
+ hand_side: "left" or "right" hand
+ detection_data: Extracted detection data
+ det_bboxes: Detectron2 detection results
+ output_bboxes: Output bounding box array to update
+ """
+ bbox = detection_data[f"{hand_side}_bboxes"][idx]
+ overlap_scores = []
+
+ for det_bbox in det_bboxes:
+ overlap_score = get_overlap_score(bbox, det_bbox)
+ overlap_scores.append(overlap_score)
+
+ if np.max(overlap_scores) > DEFAULT_OVERLAP_THRESHOLD:
+ best_idx = np.argmax(overlap_scores)
+ output_bboxes[idx] = det_bboxes[best_idx].astype(np.int32)
+
+ @staticmethod
+ def _save_results(
+ paths: Paths,
+ masks: np.ndarray,
+ sam_imgs: np.ndarray,
+ fps: int = DEFAULT_FPS
+ ) -> None:
+ """
+ Save arm segmentation results to disk.
+
+ Args:
+ paths: Paths object containing output file locations
+ masks: Combined arm segmentation masks
+ sam_imgs: SAM visualization images
+ fps: Frames per second for output videos (default: 10)
+ """
+ ArmSegmentationProcessor._create_output_directory(paths)
+
+ try:
+ ArmSegmentationProcessor._save_mask_data(paths, masks)
+ ArmSegmentationProcessor._create_videos(paths, masks, sam_imgs, fps)
+ except Exception as e:
+ logging.error(f"Error saving results: {str(e)}")
+ raise
+
+ ArmSegmentationProcessor._cleanup_temp_files(paths)
+ ArmSegmentationProcessor._update_annotation_video(paths, masks, sam_imgs, fps)
+
+ @staticmethod
+ def _create_output_directory(paths: Paths) -> None:
+ """
+ Create output directory for segmentation results.
+
+ Args:
+ paths: Paths object containing output directory location
+ """
+ if not os.path.exists(paths.segmentation_processor):
+ os.makedirs(paths.segmentation_processor)
+
+ @staticmethod
+ def _save_mask_data(paths: Paths, masks: np.ndarray) -> None:
+ """
+ Save mask data to disk.
+
+ Args:
+ paths: Paths object containing output file locations
+ masks: Segmentation masks to save
+ """
+ np.save(paths.masks_arm, masks)
+
+ @staticmethod
+ def _create_videos(paths: Paths, masks: np.ndarray, sam_imgs: np.ndarray, fps: int) -> None:
+ """
+ Create visualization videos from masks and SAM images.
+
+ Args:
+ paths: Paths object containing output file locations
+ masks: Segmentation masks
+ sam_imgs: SAM visualization images
+ fps: Frames per second for output videos
+ """
+ for name, data in [
+ ("video_masks_arm", masks),
+ ("video_sam_arm", sam_imgs),
+ ]:
+ output_path = getattr(paths, name)
+ media.write_video(output_path, data, fps=fps, codec=DEFAULT_CODEC)
+
+ @staticmethod
+ def _cleanup_temp_files(paths: Paths) -> None:
+ """
+ Clean up temporary directories created during processing.
+
+ Args:
+ paths: Paths object containing temporary directory locations
+ """
+ if os.path.exists(paths.original_images_folder):
+ shutil.rmtree(paths.original_images_folder)
+ if os.path.exists(paths.original_images_folder_reverse):
+ shutil.rmtree(paths.original_images_folder_reverse)
+
+ @staticmethod
+ def _update_annotation_video(paths: Paths, masks: np.ndarray, sam_imgs: np.ndarray, fps: int) -> None:
+ """
+ Update existing annotation video with segmentation results.
+
+ Args:
+ paths: Paths object containing annotation video location
+ masks: Segmentation masks
+ sam_imgs: SAM visualization images
+ fps: Frames per second for output video
+ """
+ if os.path.exists(paths.video_annot):
+ annot_imgs = media.read_video(paths.video_annot)
+ for idx in range(len(annot_imgs)):
+ annot_img = annot_imgs[idx]
+ h = masks[idx].shape[0]
+ w = masks[idx].shape[1]
+ # Insert segmentation visualization in the top-right quadrant
+ annot_img[:h, w:, :] = sam_imgs[idx]
+ media.write_video(paths.video_annot, annot_imgs, fps=fps, codec=ANNOTATION_CODEC)
+
+
+
+class HandSegmentationProcessor(BaseSegmentationProcessor):
+ """
+ Processor for precise hand-only segmentation in video sequences.
+
+ Attributes:
+ Inherits detector_sam from BaseSegmentationProcessor
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize the hand segmentation processor.
+
+ Args:
+ args: Command line arguments containing segmentation configuration
+ """
+ super().__init__(args)
+
+ def process_one_demo(self, data_sub_folder: str, hamer_data: Optional[Dict[str, HandSequence]] = None) -> None:
+ """
+ Process a single video demonstration to generate precise hand segmentation masks.
+
+ Args:
+ data_sub_folder: Path to the subfolder containing the demo data
+ hamer_data: Optional pre-loaded hand pose data for segmentation guidance
+
+ Raises:
+ FileNotFoundError: If required input files are not found
+ ValueError: If video frames or bounding boxes are invalid
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+
+ paths = self.get_paths(save_folder)
+ paths._setup_original_images()
+ paths._setup_original_images_reverse()
+
+ # Load and validate input data
+ imgs_rgb = self._load_video(paths.video_left)
+ bbox_data = self._load_bbox_data(paths.bbox_data)
+ if hamer_data is None:
+ hamer_data = self._load_hamer_data(paths)
+
+ # Process left and right hands separately for precise segmentation
+ left_data = self._process_hand_data(
+ imgs_rgb,
+ bbox_data["left_bboxes"],
+ bbox_data["left_bbox_min_dist_to_edge"],
+ bbox_data["left_hand_detected"],
+ hamer_data["left"],
+ paths,
+ "left"
+ )
+
+ right_data = self._process_hand_data(
+ imgs_rgb,
+ bbox_data["right_bboxes"],
+ bbox_data["right_bbox_min_dist_to_edge"],
+ bbox_data["right_hand_detected"],
+ hamer_data["right"],
+ paths,
+ "right"
+ )
+
+ # Convert to boolean masks
+ left_masks = left_data["left_masks"].astype(np.bool_)
+ left_sam_imgs = left_data["left_sam_imgs"]
+ right_masks = right_data["right_masks"].astype(np.bool_)
+ right_sam_imgs = right_data["right_sam_imgs"]
+
+ # Save results with separate left/right hand data
+ self._save_results(paths, left_masks, left_sam_imgs, right_masks, right_sam_imgs)
+
+
+ def _process_hand_data(
+ self,
+ imgs_rgb: np.ndarray,
+ bboxes: np.ndarray,
+ bbox_min_dist: np.ndarray,
+ hand_detected: np.ndarray,
+ hamer_data: HandSequence,
+ paths: Paths,
+ hand_side: str
+ ) -> Dict[str, np.ndarray]:
+ """
+ Process hand segmentation data for a single hand (left or right).
+
+ Args:
+ imgs_rgb: RGB video frames
+ bboxes: Hand bounding boxes from detection stage
+ bbox_min_dist: Minimum distances to image edges (quality metric)
+ hand_detected: Boolean flags indicating valid hand detections
+ hamer_data: Hand pose data for segmentation guidance
+ paths: Paths object for file management
+ hand_side: "left" or "right" specifying which hand to process
+
+ Returns:
+ Dictionary containing segmentation masks and visualization images
+ """
+ # Handle cases with no valid detections
+ if not hand_detected.any() or max(bbox_min_dist) == 0:
+ return {
+ f"{hand_side}_masks": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1])),
+ f"{hand_side}_sam_imgs": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1], 3))
+ }
+
+ # Extract hand pose keypoints for segmentation guidance
+ kpts_2d = hamer_data.kpts_2d
+
+ # Find the frame with highest quality (furthest from edges)
+ max_dist_idx = np.argmax(bbox_min_dist)
+ bbox = bboxes[max_dist_idx]
+ points = np.expand_dims(kpts_2d[max_dist_idx], axis=1)
+
+ # Process segmentation in both temporal directions
+ masks_forward, sam_imgs_forward = self._run_sam_segmentation(
+ paths, bbox, points, max_dist_idx, reverse=False, output_bboxes=bboxes
+ )
+ masks_reverse, sam_imgs_reverse = self._run_sam_segmentation(
+ paths, bbox, points, max_dist_idx, reverse=True, output_bboxes=bboxes
+ )
+
+ # Combine bidirectional results
+ sam_imgs = self._combine_sam_images(imgs_rgb, sam_imgs_forward, sam_imgs_reverse)
+ masks = self._combine_masks(imgs_rgb, masks_forward, masks_reverse)
+
+ return {
+ f"{hand_side}_masks": masks,
+ f"{hand_side}_sam_imgs": sam_imgs
+ }
+
+
+ def _run_sam_segmentation(
+ self,
+ paths: Paths,
+ bbox: np.ndarray,
+ points: np.ndarray,
+ max_dist_idx: int,
+ reverse: bool,
+ output_bboxes: np.ndarray
+ ) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
+ """
+ Process video segmentation in either forward or reverse temporal direction.
+
+ Args:
+ paths: Paths object for file management
+ bbox: Initial bounding box for segmentation
+ points: Hand keypoints for segmentation guidance
+ max_dist_idx: Index of highest-quality frame for initialization
+ reverse: Whether to process in reverse temporal order
+ output_bboxes: All bounding boxes for the sequence
+
+ Returns:
+ Tuple of (segmentation_masks, visualization_images)
+ """
+ return self.detector_sam.segment_video(
+ paths.original_images_folder,
+ bbox,
+ points,
+ [max_dist_idx],
+ reverse=reverse,
+ output_bboxes=output_bboxes
+ )
+
+ @staticmethod
+ def _save_results(
+ paths: Paths,
+ left_masks: np.ndarray,
+ left_sam_imgs: np.ndarray,
+ right_masks: np.ndarray,
+ right_sam_imgs: np.ndarray,
+ fps: int = DEFAULT_FPS
+ ) -> None:
+ """
+ Save hand segmentation results to disk.
+
+ Args:
+ paths: Paths object containing output file locations
+ left_masks: Left hand segmentation masks
+ left_sam_imgs: Left hand SAM visualization images
+ right_masks: Right hand segmentation masks
+ right_sam_imgs: Right hand SAM visualization images
+ fps: Frames per second for output videos (default: 10)
+ """
+ HandSegmentationProcessor._create_output_directory(paths)
+
+ try:
+ HandSegmentationProcessor._save_hand_mask_data(paths, left_masks, right_masks)
+ HandSegmentationProcessor._create_hand_videos(paths, left_masks, left_sam_imgs, right_masks, right_sam_imgs, fps)
+ except Exception as e:
+ logging.error(f"Error saving results: {str(e)}")
+ raise
+
+ HandSegmentationProcessor._cleanup_temp_files(paths)
+
+ @staticmethod
+ def _create_output_directory(paths: Paths) -> None:
+ """
+ Create output directory for segmentation results.
+
+ Args:
+ paths: Paths object containing output directory location
+ """
+ if not os.path.exists(paths.segmentation_processor):
+ os.makedirs(paths.segmentation_processor)
+
+ @staticmethod
+ def _save_hand_mask_data(paths: Paths, left_masks: np.ndarray, right_masks: np.ndarray) -> None:
+ """
+ Save hand mask data to disk.
+
+ Args:
+ paths: Paths object containing output file locations
+ left_masks: Left hand segmentation masks
+ right_masks: Right hand segmentation masks
+ """
+ np.save(paths.masks_hand_left, left_masks)
+ np.save(paths.masks_hand_right, right_masks)
+
+ @staticmethod
+ def _create_hand_videos(
+ paths: Paths,
+ left_masks: np.ndarray,
+ left_sam_imgs: np.ndarray,
+ right_masks: np.ndarray,
+ right_sam_imgs: np.ndarray,
+ fps: int
+ ) -> None:
+ """
+ Create visualization videos for hand segmentation.
+
+ Args:
+ paths: Paths object containing output file locations
+ left_masks: Left hand segmentation masks
+ left_sam_imgs: Left hand SAM visualization images
+ right_masks: Right hand segmentation masks
+ right_sam_imgs: Right hand SAM visualization images
+ fps: Frames per second for output videos
+ """
+ for name, data in [
+ ("video_masks_hand_left", left_masks),
+ ("video_masks_hand_right", right_masks),
+ ("video_sam_hand_left", left_sam_imgs),
+ ("video_sam_hand_right", right_sam_imgs),
+ ]:
+ output_path = getattr(paths, name)
+ media.write_video(output_path, data, fps=fps, codec=DEFAULT_CODEC)
+
+ @staticmethod
+ def _cleanup_temp_files(paths: Paths) -> None:
+ """
+ Clean up temporary directories created during processing.
+
+ Args:
+ paths: Paths object containing temporary directory locations
+ """
+ if os.path.exists(paths.original_images_folder):
+ shutil.rmtree(paths.original_images_folder)
+ if os.path.exists(paths.original_images_folder_reverse):
+ shutil.rmtree(paths.original_images_folder_reverse)
+
diff --git a/phantom/phantom/processors/smoothing_processor.py b/phantom/phantom/processors/smoothing_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca2b9b325dc0cca10473bd515420b7ac3ac82c7
--- /dev/null
+++ b/phantom/phantom/processors/smoothing_processor.py
@@ -0,0 +1,303 @@
+"""
+Trajectory Smoothing Processor Module
+
+This module does trajectory smoothing for end-effector positions, orientations, and gripper states
+extracted from human demonstrations.
+
+Processing Pipeline:
+1. Load processed action data from previous pipeline stages
+2. Apply Gaussian Process smoothing to 3D position trajectories
+3. Apply SLERP-based smoothing to rotation matrix trajectories
+4. Apply Gaussian Process smoothing to gripper distance trajectories
+5. Save smoothed trajectories for robot execution
+"""
+
+import os
+from typing import Optional
+import argparse
+import numpy as np
+import logging
+from sklearn.gaussian_process import GaussianProcessRegressor # type: ignore
+from sklearn.gaussian_process.kernels import RBF, WhiteKernel # type: ignore
+from scipy.spatial.transform import Rotation, Slerp
+
+from phantom.processors.base_processor import BaseProcessor
+from phantom.processors.paths import Paths
+
+logger = logging.getLogger(__name__)
+
+def gaussian_kernel(size: int, sigma: float) -> np.ndarray:
+ """
+ Generate a centered Gaussian kernel for local smoothing operations.
+
+ Args:
+ size: Size of the kernel (should be odd for proper centering)
+ sigma: Standard deviation of the Gaussian distribution
+
+ Returns:
+ Normalized Gaussian kernel array
+
+ Raises:
+ ValueError: If size is not positive
+ """
+ if size <= 0:
+ raise ValueError("Kernel size must be positive")
+
+ x = np.arange(size) - size // 2
+ kernel = np.exp(-0.5 * (x / sigma) ** 2)
+ return kernel / kernel.sum()
+
+class SmoothingProcessor(BaseProcessor):
+ """
+ This processor takes raw trajectory data extracted from human demonstrations
+ and applies smoothing techniques to create executable robot trajectories.
+
+ Attributes:
+ bimanual_setup (str): Configuration mode ("single_arm" or bimanual type)
+ target_hand (str): Target hand for single-arm processing ("left" or "right")
+ """
+ def __init__(self, args: argparse.Namespace) -> None:
+ """
+ Initialize the smoothing processor with configuration parameters.
+
+ Args:
+ args: Command line arguments containing smoothing configuration
+ including bimanual setup and target hand specification
+ """
+ super().__init__(args)
+
+ def process_one_demo(self, data_sub_folder: str) -> None:
+ """
+ Process and smooth trajectories for a single demonstration.
+
+ Args:
+ data_sub_folder: Path to demonstration data folder containing
+ processed action trajectories from previous stages
+ """
+ save_folder = self.get_save_folder(data_sub_folder)
+ paths = self.get_paths(save_folder)
+
+ # Handle single-arm processing mode
+ if self.bimanual_setup == "single_arm":
+ self._process_single_arm_demo(paths)
+ else:
+ self._process_bimanual_demo(paths)
+
+ def _process_single_arm_demo(self, paths: Paths) -> None:
+ """
+ Process single-arm demonstration data.
+
+ Args:
+ paths: Paths object containing file locations
+ """
+ # Load action data for target hand
+ actions_path = self._get_actions_path(paths)
+ actions = np.load(actions_path, allow_pickle=True)
+
+ # Apply smoothing to each trajectory component
+ smoothed_ee_pts = self.gaussian_process_smoothing(actions["ee_pts"])
+
+ # Apply rotation smoothing with configuration-specific parameters
+ if self.constrained_hand:
+ smoothed_ee_oris = self.gaussian_slerp_smoothing(
+ actions["ee_oris"], sigma=10.0, kernel_size=41
+ )
+ else:
+ smoothed_ee_oris = self.gaussian_slerp_smoothing(
+ actions["ee_oris"], sigma=10.0
+ )
+
+ smoothed_ee_widths = self.gaussian_process_smoothing(actions["ee_widths"])
+
+ # Save results based on target hand
+ if self.target_hand == "left":
+ self._save_results(paths, smoothed_ee_pts_left=smoothed_ee_pts,
+ smoothed_ee_oris_left=smoothed_ee_oris,
+ smoothed_ee_widths_left=smoothed_ee_widths)
+ else:
+ self._save_results(paths, smoothed_ee_pts_right=smoothed_ee_pts,
+ smoothed_ee_oris_right=smoothed_ee_oris,
+ smoothed_ee_widths_right=smoothed_ee_widths)
+
+ def _process_bimanual_demo(self, paths: Paths) -> None:
+ """
+ Process bimanual demonstration data.
+
+ Args:
+ paths: Paths object containing file locations
+ """
+ # Load data for both hands
+ actions_left_path = str(paths.actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ actions_right_path = str(paths.actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ actions_left = np.load(actions_left_path, allow_pickle=True)
+ actions_right = np.load(actions_right_path, allow_pickle=True)
+
+ # Apply position smoothing using Gaussian Process regression
+ smoothed_ee_pts_left = self.gaussian_process_smoothing(actions_left["ee_pts"])
+ smoothed_ee_pts_right = self.gaussian_process_smoothing(actions_right["ee_pts"])
+
+ # Apply rotation smoothing using SLERP with optimized parameters for bimanual coordination
+ smoothed_ee_oris_left = self.gaussian_slerp_smoothing(
+ actions_left["ee_oris"], sigma=10.0, kernel_size=21
+ )
+ smoothed_ee_oris_right = self.gaussian_slerp_smoothing(
+ actions_right["ee_oris"], sigma=10.0, kernel_size=21
+ )
+
+ # Apply gripper distance smoothing
+ smoothed_ee_widths_left = self.gaussian_process_smoothing(actions_left["ee_widths"])
+ smoothed_ee_widths_right = self.gaussian_process_smoothing(actions_right["ee_widths"])
+
+ # Save all smoothed trajectories
+ self._save_results(paths, smoothed_ee_pts_left, smoothed_ee_oris_left, smoothed_ee_widths_left,
+ smoothed_ee_pts_right, smoothed_ee_oris_right, smoothed_ee_widths_right)
+
+ def _get_actions_path(self, paths: Paths) -> str:
+ """
+ Get the appropriate actions file path based on target hand.
+
+ Args:
+ paths: Paths object containing file locations
+
+ Returns:
+ Path to the actions file for the target hand
+ """
+ if self.target_hand == "left":
+ base_path = str(paths.actions_left)
+ else:
+ base_path = str(paths.actions_right)
+ return base_path.split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+
+ def _save_results(self, paths: Paths, smoothed_ee_pts_left: Optional[np.ndarray] = None,
+ smoothed_ee_oris_left: Optional[np.ndarray] = None,
+ smoothed_ee_widths_left: Optional[np.ndarray] = None,
+ smoothed_ee_pts_right: Optional[np.ndarray] = None,
+ smoothed_ee_oris_right: Optional[np.ndarray] = None,
+ smoothed_ee_widths_right: Optional[np.ndarray] = None) -> None:
+ """
+ Save smoothed trajectory results to disk.
+
+ Args:
+ paths: Paths object containing output file locations
+ smoothed_ee_pts_left: Smoothed left hand position trajectory
+ smoothed_ee_oris_left: Smoothed left hand orientation trajectory
+ smoothed_ee_widths_left: Smoothed left hand gripper trajectory
+ smoothed_ee_pts_right: Smoothed right hand position trajectory
+ smoothed_ee_oris_right: Smoothed right hand orientation trajectory
+ smoothed_ee_widths_right: Smoothed right hand gripper trajectory
+ """
+ # Create output directory
+ os.makedirs(paths.smoothing_processor, exist_ok=True)
+
+ # Save left hand trajectories if provided
+ if smoothed_ee_pts_left is not None:
+ smoothed_actions_left_path = str(paths.smoothed_actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ np.savez(smoothed_actions_left_path,
+ ee_pts=smoothed_ee_pts_left,
+ ee_oris=smoothed_ee_oris_left,
+ ee_widths=smoothed_ee_widths_left)
+
+ # Save right hand trajectories if provided
+ if smoothed_ee_pts_right is not None:
+ smoothed_actions_right_path = str(paths.smoothed_actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
+ np.savez(smoothed_actions_right_path,
+ ee_pts=smoothed_ee_pts_right,
+ ee_oris=smoothed_ee_oris_right,
+ ee_widths=smoothed_ee_widths_right)
+
+ @staticmethod
+ def gaussian_slerp_smoothing(rot_mats: np.ndarray, sigma: float = 2, kernel_size: int = 9) -> np.ndarray:
+ """
+ Apply Gaussian-weighted SLERP smoothing to rotation matrices.
+
+ Args:
+ rot_mats: Array of rotation matrices to smooth, shape (N, 3, 3)
+ sigma: Standard deviation for Gaussian kernel
+ kernel_size: Size of the smoothing kernel (should be odd)
+
+ Returns:
+ Array of smoothed rotation matrices, shape (N, 3, 3)
+
+ Raises:
+ ValueError: If kernel_size is not odd
+ """
+ if kernel_size % 2 != 1:
+ raise ValueError("Kernel size must be odd for proper centering")
+
+ half_k = kernel_size // 2
+ N = len(rot_mats)
+
+ # Step 1: Convert rotation matrices to quaternions for interpolation
+ quats = Rotation.from_matrix(rot_mats).as_quat()
+
+ # Step 2: Apply hemisphere correction to ensure quaternion continuity
+ quats_fixed = [quats[0]]
+ for i in range(1, N):
+ q = quats[i]
+ # Choose quaternion hemisphere that minimizes distance to previous quaternion
+ if np.dot(q, quats_fixed[-1]) < 0:
+ q = -q
+ quats_fixed.append(q)
+ quats_fixed = np.array(quats_fixed)
+
+ # Step 3: Prepare normalized Gaussian weights for local smoothing
+ weights = gaussian_kernel(kernel_size, sigma)
+
+ # Step 4: Apply weighted SLERP averaging for each time point
+ smoothed_rots = []
+ for i in range(N):
+ # Define local neighborhood around current time point
+ start = max(0, i - half_k)
+ end = min(N, i + half_k + 1)
+
+ # Extract local quaternions and corresponding weights
+ local_quats = quats_fixed[start:end]
+ local_weights = weights[half_k - (i - start): half_k + (end - i)]
+
+ # Normalize weights for current neighborhood
+ local_weights /= local_weights.sum()
+
+ # Initialize weighted average with first quaternion
+ q_avg = local_quats[0]
+ r_avg = Rotation.from_quat(q_avg)
+
+ # Iteratively apply weighted SLERP interpolation
+ for j in range(1, len(local_quats)):
+ r_next = Rotation.from_quat(local_quats[j])
+ # Use SLERP with weight proportional to current quaternion's contribution
+ r_avg = Slerp([0, 1], Rotation.concatenate([r_avg, r_next]))([local_weights[j] / (local_weights[:j+1].sum())])[0]
+
+ smoothed_rots.append(r_avg.as_matrix())
+
+ return np.stack(smoothed_rots)
+
+ @staticmethod
+ def gaussian_process_smoothing(pts: np.ndarray) -> np.ndarray:
+ """
+ Apply Gaussian process smoothing to trajectory points.
+
+ Args:
+ pts: Trajectory points to smooth, shape (N,) for 1D or (N, D) for multi-dimensional
+
+ Returns:
+ Smoothed trajectory points with same shape as input
+
+ Raises:
+ ValueError: If pts is empty
+ """
+ if len(pts) == 0:
+ raise ValueError("Cannot smooth empty trajectory")
+
+ # Create time indices as features for GP regression
+ time = np.arange(len(pts))[:, None] # Time as a single feature
+
+ # Configure GP kernel: RBF for smoothness + White noise for robustness
+ kernel = RBF(length_scale=1) + WhiteKernel(noise_level=1)
+ gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
+
+ # Handle 1D trajectory case
+ if pts.ndim == 1:
+ return gpr.fit(time, pts).predict(time)
+
+ # Handle multi-dimensional trajectory case by processing each dimension independently
+ return np.column_stack([gpr.fit(time, pts[:, i]).predict(time) for i in range(pts.shape[1])])
\ No newline at end of file
diff --git a/phantom/phantom/twin_bimanual_robot.py b/phantom/phantom/twin_bimanual_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..48dae775515e4b785a5b0a97c496a4c2373db6e7
--- /dev/null
+++ b/phantom/phantom/twin_bimanual_robot.py
@@ -0,0 +1,597 @@
+"""
+Virtual twin bimanual robot implementation for MuJoCo simulation.
+
+This module provides a TwinBimanualRobot class that creates a virtual representation
+of a bimanual (two-arm) robot system in MuJoCo using the robosuite framework.
+The twin robot can be controlled via end-effector poses or joint positions and
+provides observation data including RGB images, depth maps, and robot masks.
+"""
+
+from collections import deque
+import re
+import cv2
+import pdb
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.spatial.transform import Rotation
+from dataclasses import dataclass
+from typing import Tuple, Union, Any
+
+from robosuite.controllers import load_controller_config # type: ignore
+from robosuite.utils.camera_utils import get_real_depth_map # type: ignore
+from robomimic.envs.env_robosuite import EnvRobosuite # type: ignore
+import robomimic.utils.obs_utils as ObsUtils # type: ignore
+
+
+@dataclass
+class MujocoCameraParams:
+ """
+ Camera parameters for MuJoCo simulation.
+
+ Attributes:
+ name: Camera name identifier
+ pos: 3D position of camera in world coordinates
+ ori_wxyz: Camera orientation as quaternion (w, x, y, z)
+ fov: Field of view in degrees
+ resolution: Image resolution as (width, height)
+ sensorsize: Physical sensor size in mm
+ principalpixel: Principal point coordinates in pixels
+ focalpixel: Focal length in pixels
+ """
+ name: str
+ pos: np.ndarray
+ ori_wxyz: np.ndarray
+ fov: float
+ resolution: Tuple[int, int]
+ sensorsize: np.ndarray
+ principalpixel: np.ndarray
+ focalpixel: np.ndarray
+
+# Color constants for visualization (RGBA format)
+THUMB_COLOR = [0, 1, 0, 1] # Green for thumb
+INDEX_COLOR = [1, 0, 0, 1] # Red for index finger
+HAND_EE_COLOR = [0, 0, 1, 1] # Blue for hand end-effector
+
+# Transformation matrix for Epic Kitchen setup - converts from base frame to robot frame
+BASE_T_1 = np.array([[0.0, -1.0, 0.0, 0.0],
+ [ 0.5, 0.0, 0.866, 0.2],
+ [-0.866, 0.0, 0.5, 1.50],
+ [ 0.0, 0.0, 0.0, 1.0]])
+
+def convert_real_camera_ori_to_mujoco(camera_ori_matrix: np.ndarray) -> np.ndarray:
+ """
+ Convert camera orientation from real world to MuJoCo XML format.
+
+ MuJoCo uses a different coordinate system convention, so we need to
+ flip the Y and Z axes of the rotation matrix before converting to quaternion.
+
+ Args:
+ camera_ori_matrix: 3x3 rotation matrix in real-world coordinates
+
+ Returns:
+ Camera orientation as quaternion in MuJoCo format (w, x, y, z)
+ """
+ camera_ori_matrix[:, [1, 2]] = -camera_ori_matrix[:, [1, 2]]
+ r = Rotation.from_matrix(camera_ori_matrix)
+ camera_ori_wxyz = r.as_quat(scalar_first=True)
+ return camera_ori_wxyz
+
+class TwinBimanualRobot:
+ """
+ Virtual twin of a bimanual robot system in MuJoCo simulation.
+
+ This class creates a simulated bimanual robot that can be controlled via
+ end-effector poses or joint positions. It provides functionality for:
+ - Robot pose control (OSC or joint-level)
+ - Camera observation collection (RGB, depth, segmentation)
+ - Robot and gripper mask generation
+ - Observation history management
+ """
+
+ def __init__(self, robot_name: str, gripper_name: str, bimanual_setup: str,
+ camera_params: MujocoCameraParams, camera_height: int, camera_width: int,
+ render: bool, n_steps_short: int, n_steps_long: int, square: bool = False,
+ debug_cameras: list[str] = [], epic: bool = False, joint_controller: bool = False):
+ """
+ Initialize the bimanual robot twin.
+
+ Args:
+ robot_name: Type of robot (e.g., "Kinova3")
+ gripper_name: Type of gripper (e.g., "Robotiq85")
+ bimanual_setup: Configuration for bimanual setup
+ camera_params: Camera configuration parameters
+ camera_height: Height of camera images in pixels
+ camera_width: Width of camera images in pixels
+ render: Whether to render the simulation visually
+ n_steps_short: Number of simulation steps for quick movements
+ n_steps_long: Number of simulation steps for initial/slow movements
+ square: Whether to crop images to square aspect ratio
+ debug_cameras: Additional camera names for debugging views
+ epic: Whether to use Epic Kitchen coordinate system
+ joint_controller: Whether to use joint-level control instead of OSC
+ """
+ # Store configuration parameters
+ self.robot_name = robot_name
+ self.gripper_name = gripper_name
+ self.bimanual_setup = bimanual_setup
+ self.camera_params = camera_params
+ self.render = render
+ self.n_steps_long = n_steps_long
+ self.n_steps_short= n_steps_short
+ self.num_frames = 2 # Number of frames to keep in observation history
+ self.camera_height = camera_height
+ self.camera_width = camera_width
+ self.camera_name = "zed" # Main camera name
+ self.square = square
+ self.debug_cameras = list(debug_cameras) if debug_cameras else []
+ self.epic = epic # Epic Kitchen mode flag
+ self.joint_controller = joint_controller # Control mode flag
+
+ # Configure observation specifications for robomimic
+ obs_spec = dict(
+ obs=dict(
+ low_dim=["robot0_eef_pos"], # End-effector position observations
+ rgb=[f"{self.camera_params.name}_image"] + [f"{cam}_image" for cam in self.debug_cameras],
+ ),
+ )
+ ObsUtils.initialize_obs_utils_with_obs_specs(
+ obs_modality_specs=obs_spec)
+
+ # Configure robosuite environment options
+ options: dict[str, Union[str, list[str], dict[str, Any], bool, int, np.ndarray]] = {}
+ options["env_name"] = "PhantomBimanual"
+ options["bimanual_setup"] = bimanual_setup
+ options["robots"] = [self.robot_name, self.robot_name] # Two identical robots
+ if self.robot_name == "Kinova3":
+ options["gripper_types"] = [f"{self.gripper_name}GripperRealKinova", f"{self.gripper_name}GripperRealKinova"]
+ else:
+ options["gripper_types"] = [f"{self.gripper_name}Gripper", f"{self.gripper_name}Gripper"]
+
+ # Configure controller (OSC pose control by default)
+ controller_config = load_controller_config(default_controller="OSC_POSE")
+ controller_config["control_delta"] = False # Use absolute positioning
+ controller_config["uncouple_pos_ori"] = False # Couple position and orientation
+ options["controller_configs"] = controller_config
+
+ # Override with joint controller if specified
+ if self.joint_controller:
+ controller_config = load_controller_config(default_controller="JOINT_POSITION")
+ controller_config["input_type"] = "absolute"
+ controller_config["input_max"] = 10
+ controller_config["input_min"] = -10
+ controller_config["output_max"] = 10
+ controller_config["output_min"] = -10
+ controller_config["kd"] = 200 # Derivative gain
+ controller_config["kv"] = 200 # Velocity gain
+ controller_config["kp"] = 1000 # Proportional gain
+ controller_config["kp_limits"] = [0, 1000] # Proportional gain limits
+ options["controller_configs"] = controller_config
+
+ # Camera and observation settings
+ options["camera_heights"] = self.camera_height
+ options["camera_widths"] = self.camera_width
+ options["camera_segmentations"] = "instance" # Instance segmentation masks
+ options["direct_gripper_control"] = True
+ options["use_depth_obs"] = True
+
+ # Apply Epic Kitchen coordinate transformation if enabled
+ if self.epic:
+ self.base_T_1 = BASE_T_1
+ # Transform camera position and orientation to Epic Kitchen frame
+ self.camera_params.pos = self.base_T_1[:3, :3] @ self.camera_params.pos + self.base_T_1[:3, 3]
+ camera_ori_matrix = self.base_T_1[:3, :3] @ Rotation.from_quat(self.camera_params.ori_wxyz, scalar_first=True).as_matrix()
+ self.camera_params.ori_wxyz = Rotation.from_matrix(camera_ori_matrix).as_quat(scalar_first=True)
+
+ # Set camera parameters
+ options["camera_pos"] = self.camera_params.pos
+ options["camera_quat_wxyz"] = self.camera_params.ori_wxyz
+ options["camera_sensorsize"] = self.camera_params.sensorsize
+ options["camera_principalpixel"] = self.camera_params.principalpixel
+ options["camera_focalpixel"] = self.camera_params.focalpixel
+
+ # Create the robosuite environment
+ self.env = EnvRobosuite(
+ **options,
+ render=render,
+ render_offscreen=True, # Enable offscreen rendering for image capture
+ use_image_obs=True,
+ camera_names=[self.camera_params.name] + self.debug_cameras,
+ control_freq=20, # 20 Hz control frequency
+ )
+
+ # Initialize environment and compute robot base position
+ self.reset()
+ self.robot_base_pos = np.array([0, 0, self.env.env.robot_base_height+self.env.env.robot_base_offset])
+
+
+ def reset(self):
+ """Reset environment and clear observation history."""
+ self.env.reset()
+ self.obs_history = deque()
+
+ def close(self):
+ """Close the simulation environment."""
+ self.env.env.close()
+
+ def get_action_from_ee_pose(self, ee_pos: np.ndarray, ee_quat_xyzw: np.ndarray, gripper_action: float,
+ use_base_offset: bool = False) -> np.ndarray:
+ """
+ Convert end-effector pose to robot action vector.
+
+ This method transforms the desired end-effector position and orientation
+ into the action format expected by the robot controller.
+
+ Args:
+ ee_pos: End-effector position as 3D array
+ ee_quat_xyzw: End-effector orientation as quaternion (x, y, z, w)
+ gripper_action: Gripper action value
+ use_base_offset: Whether to add robot base offset to position
+
+ Returns:
+ Action vector [position(3), rotation(3), gripper(1)]
+ """
+ # Handle batch inputs by taking the last element
+ if ee_pos.ndim > 1:
+ ee_pos = ee_pos[-1]
+ ee_quat_xyzw = ee_quat_xyzw[-1]
+
+ # Add base offset if requested and not in Epic mode
+ if use_base_offset and not self.epic:
+ ee_pos = ee_pos + self.robot_base_pos
+
+ # Apply coordinate transformations based on mode
+ if self.epic:
+ # Transform position and orientation to Epic Kitchen coordinate frame
+ ee_pos = self.base_T_1[:3, 3] + self.base_T_1[:3, :3] @ ee_pos
+ axis_angle = Rotation.from_matrix(self.base_T_1[:3, :3] @ Rotation.from_quat(ee_quat_xyzw).as_matrix()).as_rotvec()
+ elif not self.epic:
+ # Apply 135-degree Z rotation for standard setup
+ rot = Rotation.from_quat(ee_quat_xyzw)
+ rot_135deg = Rotation.from_euler('z', 135, degrees=True)
+ new_rot = rot * rot_135deg
+ axis_angle = new_rot.as_rotvec()
+
+ # Combine into action vector
+ action = np.concatenate([ee_pos, axis_angle, [gripper_action]])
+
+ return action
+
+ def _get_initial_obs_history(self, state: dict) -> deque:
+ """
+ Initialize observation history by repeating the first observation.
+
+ This creates a history buffer filled with the initial robot state,
+ which is useful for algorithms that require temporal context.
+
+ Args:
+ state: Initial robot state dictionary
+
+ Returns:
+ Deque containing repeated initial observations
+ """
+ obs_history = deque(
+ [self.move_to_target_state(state, init=True)],
+ maxlen=self.num_frames,
+ )
+ # Fill remaining slots with copies of the initial observation
+ for _ in range(self.num_frames-1):
+ obs_history.append(self.move_to_target_state(state))
+ return obs_history
+
+ def get_obs_history(self, state: dict) -> list:
+ """
+ Get observation history with specified length.
+
+ Maintains a rolling buffer of recent observations for temporal context.
+
+ Args:
+ state: Current robot state dictionary
+
+ Returns:
+ List of recent observations (length = self.num_frames)
+ """
+ if len(self.obs_history) == 0:
+ # Initialize history if empty
+ self.obs_history = self._get_initial_obs_history(state)
+ else:
+ # Add new observation to history
+ self.obs_history.append(self.move_to_target_state(state))
+ return list(self.obs_history)
+
+ def move_to_target_state(self, state: dict, init=False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Move robot to target state and collect observation data.
+
+ This is the main method for controlling the robot and collecting observations.
+ It handles both pose and joint control modes, and collects RGB, depth,
+ and segmentation data along with tracking errors.
+
+ Args:
+ state: Target state containing positions, orientations, and gripper states
+ init: Whether this is an initialization step (uses longer movement time)
+
+ Returns:
+ Dictionary containing observation data:
+ - robot_mask: Binary mask showing robot pixels
+ - gripper_mask: Binary mask showing gripper pixels
+ - rgb_img: RGB camera image
+ - depth_img: Depth camera image
+ - robot_pos: Robot end-effector position
+ - left_pos_err: Left arm position tracking error
+ - right_pos_err: Right arm position tracking error
+ - {cam}_img: Additional camera images if debug_cameras specified
+ """
+ # Convert gripper positions to actions based on controller type
+ if not self.joint_controller:
+ # Use pose controller with gripper position mapping
+ gripper_action_0 = self._convert_handgripper_pos_to_action(state["gripper_pos"][0])
+ gripper_action_1 = self._convert_handgripper_pos_to_action(state["gripper_pos"][1])
+ gripper_action = [gripper_action_0, gripper_action_1]
+ else:
+ # Use joint controller with direct gripper control
+ gripper_action = [state["gripper_pos"][0]*255, state["gripper_pos"][1]*255]
+
+ # Choose movement duration based on whether this is initialization
+ n_steps = self.n_steps_long if init else self.n_steps_short
+
+ # Execute movement based on controller type
+ if not self.joint_controller:
+ # Move using pose control
+ obs = self.move_to_pose(state["pos"], state["ori_xyzw"], gripper_action, n_steps)
+ else:
+ # Move using joint control
+ obs = self.move_to_pose(state["pos"], state["ori_xyzw"], gripper_action, n_steps, state["q0"], state["q1"])
+
+ # Extract observation data from simulation
+ robot_mask = np.squeeze(self.get_robot_mask(obs))
+ gripper_mask = np.squeeze(self.get_gripper_mask(obs))
+ rgb_img = self.get_image(obs)
+ depth_img = self.get_depth_image(obs)
+ robot_pos = obs["robot0_eef_pos"] - self.robot_base_pos
+
+ # Calculate end-effector tracking errors for both arms
+ if not self.epic:
+ # Standard coordinate frame
+ right_pos_error = np.linalg.norm(obs['robot0_eef_pos']-self.robot_base_pos - state["pos"][0])
+ left_pos_error = np.linalg.norm(obs['robot1_eef_pos']-self.robot_base_pos - state["pos"][1])
+ else:
+ # Epic Kitchen coordinate frame
+ right_pos_error = np.linalg.norm(obs['robot0_eef_pos']-self.base_T_1[:3, 3] - self.base_T_1[:3, :3] @ state["pos"][0])
+ left_pos_error = np.linalg.norm(obs['robot1_eef_pos']-self.base_T_1[:3, 3] - self.base_T_1[:3, :3] @ state["pos"][1])
+
+ # Compile output dictionary
+ output = {
+ "robot_mask": robot_mask,
+ "gripper_mask": gripper_mask,
+ "rgb_img": rgb_img,
+ "depth_img": depth_img,
+ "robot_pos": robot_pos,
+ "left_pos_err": left_pos_error,
+ "right_pos_err": right_pos_error,
+ }
+
+ # Add debug camera images if specified
+ for cam in self.debug_cameras:
+ cam_img = self.get_camera_image(obs, cam)
+ output[f"{cam}_img"] = cam_img
+
+ return output
+
+ def _convert_handgripper_pos_to_action(self, gripper_pos: float) -> np.ndarray:
+ """
+ Convert hand gripper position to robot gripper action.
+
+ Maps from physical gripper opening distance to robot action values.
+ Different gripper types may have different mappings.
+
+ Args:
+ gripper_pos: Gripper opening distance in meters
+
+ Returns:
+ Robot gripper action value (0-255 for Robotiq85)
+
+ Raises:
+ ValueError: If gripper type is not supported
+ """
+ if self.gripper_name == "Robotiq85":
+ # Robotiq85 gripper specifications
+ min_gripper_pos, max_gripper_pos = 0.0, 0.085 # 0 to 8.5cm opening
+ gripper_pos = np.clip(gripper_pos, min_gripper_pos, max_gripper_pos)
+ open_gripper_action, closed_gripper_action = 0, 255 # 0=open, 255=closed
+ # Linear interpolation between open and closed states
+ return np.interp(gripper_pos, [min_gripper_pos, max_gripper_pos], [closed_gripper_action, open_gripper_action])
+ else:
+ raise ValueError(f"Gripper name {self.gripper_name} not supported")
+
+ def move_to_pose(self, ee_pos: dict, ee_ori: dict, gripper_action: dict, n_steps: int, q0=None, q1=None) -> dict:
+ """
+ Execute robot movement to target pose.
+
+ Sends action commands to the simulation for the specified number of steps.
+ Handles both pose control (OSC) and joint control modes.
+
+ Args:
+ ee_pos: End-effector positions for both arms {0: pos0, 1: pos1}
+ ee_ori: End-effector orientations for both arms {0: ori0, 1: ori1}
+ gripper_action: Gripper actions for both arms {0: grip0, 1: grip1}
+ n_steps: Number of simulation steps to execute
+ q0: Joint positions for arm 0 (only for joint controller)
+ q1: Joint positions for arm 1 (only for joint controller)
+
+ Returns:
+ Final observation dictionary from simulation
+ """
+ if not self.joint_controller:
+ # Pose control mode: convert poses to actions
+ action_0 = self.get_action_from_ee_pose(ee_pos[0], ee_ori[0], gripper_action[0], use_base_offset=True)
+ action_1 = self.get_action_from_ee_pose(ee_pos[1], ee_ori[1], gripper_action[1], use_base_offset=True)
+ action = np.concatenate([action_0, action_1])
+ else:
+ # Joint control mode: convert joint angles from degrees to radians
+ q0_new = []
+ for rot_q in q0:
+ if rot_q >= 180:
+ q0_new.append((rot_q/180*np.pi-2*np.pi)) # Handle angle wrapping
+ else:
+ q0_new.append(rot_q/180*np.pi)
+ q1_new = []
+ for rot_q in q1:
+ if rot_q >= 180:
+ q1_new.append((rot_q/180*np.pi-2*np.pi)) # Handle angle wrapping
+ else:
+ q1_new.append(rot_q/180*np.pi)
+
+ # Combine joint positions and gripper actions
+ action_0 = q0_new
+ action_1 = q1_new
+ action = np.concatenate([action_0, np.array(gripper_action[0]).reshape(1,), action_1, np.array(gripper_action[1]).reshape(1,)])
+
+ # Execute action for specified number of steps
+ for _ in range(n_steps):
+ obs, _, _, _ = self.env.step(action)
+ if self.render:
+ self.env.render()
+ return obs
+
+ def get_proprioception(self, obs: dict) -> np.ndarray:
+ """
+ Get proprioceptive information (robot's internal state).
+
+ Args:
+ obs: Observation dictionary from simulation
+
+ Returns:
+ End-effector position of first robot
+ """
+ pos = obs["robot0_eef_pos"]
+ return pos
+
+ def get_image(self, obs: dict) -> np.ndarray:
+ """
+ Extract RGB image from observation.
+
+ Handles image format conversion and optional square cropping.
+
+ Args:
+ obs: Observation dictionary containing image data
+
+ Returns:
+ RGB image as numpy array (H, W, 3)
+ """
+ img = obs[f"{self.camera_name}_image"]
+ img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ return img
+
+ def get_camera_image(self, obs: dict, camera_name: str) -> np.ndarray:
+ """
+ Extract RGB image from specific camera.
+
+ Args:
+ obs: Observation dictionary containing image data
+ camera_name: Name of the camera to extract image from
+
+ Returns:
+ RGB image as numpy array (H, W, 3)
+ """
+ img = obs[f"{camera_name}_image"]
+ img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ return img
+
+ def get_seg_image(self, obs: dict) -> np.ndarray:
+ """
+ Extract instance segmentation image.
+
+ Args:
+ obs: Observation dictionary containing segmentation data
+
+ Returns:
+ Segmentation image as uint8 array where each pixel value
+ represents a different object instance ID
+ """
+ img = obs[f"{self.camera_name}_segmentation_instance"]
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ img = img.astype(np.uint8)
+ return img
+
+ def get_depth_image(self, obs: dict) -> np.ndarray:
+ """
+ Extract and process depth image.
+
+ Converts raw depth buffer to real-world depth values using
+ robosuite's depth processing utilities.
+
+ Args:
+ obs: Observation dictionary containing depth data
+
+ Returns:
+ Depth image as numpy array where values represent
+ distance in meters
+ """
+ img = obs[f"{self.camera_name}_depth"]
+ img = get_real_depth_map(sim=self.env.env.sim, depth_map=img)
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ return img
+
+ def get_robot_mask(self, obs: dict) -> np.ndarray:
+ """
+ Generate binary mask for robot pixels.
+
+ Uses instance segmentation to identify which pixels belong to
+ the robot arms (instance IDs 1 and 4).
+
+ Args:
+ obs: Observation dictionary containing segmentation data
+
+ Returns:
+ Binary mask where 1 indicates robot pixels, 0 otherwise
+ """
+ seg_img = self.get_seg_image(obs)
+ mask = np.zeros_like(seg_img)
+ mask[seg_img == 1] = 1 # First robot arm
+ mask[seg_img == 4] = 1 # Second robot arm
+ return mask
+
+ def get_gripper_mask(self, obs: dict) -> np.ndarray:
+ """
+ Generate binary mask for gripper pixels.
+
+ Uses instance segmentation to identify which pixels belong to
+ the robot grippers (instance IDs 3 and 6).
+
+ Args:
+ obs: Observation dictionary containing segmentation data
+
+ Returns:
+ Binary mask where 1 indicates gripper pixels, 0 otherwise
+ """
+ seg_img = self.get_seg_image(obs)
+ mask = np.zeros_like(seg_img)
+ mask[seg_img == 3] = 1 # First gripper
+ mask[seg_img == 6] = 1 # Second gripper
+ return mask
diff --git a/phantom/phantom/twin_robot.py b/phantom/phantom/twin_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..006539ace12d02a24354c21a73edcbfe447f39bf
--- /dev/null
+++ b/phantom/phantom/twin_robot.py
@@ -0,0 +1,490 @@
+"""
+Virtual twin single-arm robot implementation for MuJoCo simulation.
+
+This module provides a TwinRobot class that creates a virtual representation
+of a single-arm robot system in MuJoCo using the robosuite framework.
+The twin robot can be controlled via end-effector poses and provides
+observation data including RGB images, depth maps, and robot masks.
+"""
+
+from collections import deque
+import cv2
+import numpy as np
+from scipy.spatial.transform import Rotation
+from dataclasses import dataclass
+from typing import Tuple, Union, Any
+
+from robosuite.controllers import load_controller_config # type: ignore
+from robosuite.utils.camera_utils import get_real_depth_map # type: ignore
+from robomimic.envs.env_robosuite import EnvRobosuite # type: ignore
+import robomimic.utils.obs_utils as ObsUtils # type: ignore
+
+
+@dataclass
+class MujocoCameraParams:
+ """
+ Camera parameters for MuJoCo simulation.
+
+ Attributes:
+ name: Camera name identifier
+ pos: 3D position of camera in world coordinates
+ ori_wxyz: Camera orientation as quaternion (w, x, y, z)
+ fov: Field of view in degrees
+ resolution: Image resolution as (width, height)
+ sensorsize: Physical sensor size in mm
+ principalpixel: Principal point coordinates in pixels
+ focalpixel: Focal length in pixels
+ """
+ name: str
+ pos: np.ndarray
+ ori_wxyz: np.ndarray
+ fov: float
+ resolution: Tuple[int, int]
+ sensorsize: np.ndarray
+ principalpixel: np.ndarray
+ focalpixel: np.ndarray
+
+# Color constants for visualization (RGBA format)
+THUMB_COLOR = [0, 1, 0, 1] # Green for thumb
+INDEX_COLOR = [1, 0, 0, 1] # Red for index finger
+HAND_EE_COLOR = [0, 0, 1, 1] # Blue for hand end-effector
+
+def convert_real_camera_ori_to_mujoco(camera_ori_matrix: np.ndarray) -> np.ndarray:
+ """
+ Convert camera orientation from real world to MuJoCo XML format.
+
+ MuJoCo uses a different coordinate system convention, so we need to
+ flip the Y and Z axes of the rotation matrix before converting to quaternion.
+
+ Args:
+ camera_ori_matrix: 3x3 rotation matrix in real-world coordinates
+
+ Returns:
+ Camera orientation as quaternion in MuJoCo format (w, x, y, z)
+ """
+ camera_ori_matrix[:, [1, 2]] = -camera_ori_matrix[:, [1, 2]]
+ r = Rotation.from_matrix(camera_ori_matrix)
+ camera_ori_wxyz = r.as_quat(scalar_first=True)
+ return camera_ori_wxyz
+
+
+class TwinRobot:
+ """
+ Virtual twin of a single-arm robot system in MuJoCo simulation.
+
+ This class creates a simulated single-arm robot that can be controlled via
+ end-effector poses. It provides functionality for:
+ - Robot pose control using OSC (Operational Space Control)
+ - Camera observation collection (RGB, depth, segmentation)
+ - Robot and gripper mask generation
+ - Observation history management
+ """
+
+ # Robot configuration constants
+ DEFAULT_ROBOT_BASE_POS = np.array([-0.56, 0, 0.912])
+
+ def __init__(self, robot_name: str, gripper_name: str, camera_params: MujocoCameraParams, camera_height: int, camera_width: int,
+ render: bool, n_steps_short: int, n_steps_long: int, debug_cameras: list[str] = [],
+ square: bool = False):
+ """
+ Initialize the single-arm robot twin.
+
+ Args:
+ robot_name: Type of robot (e.g., "Kinova3")
+ gripper_name: Type of gripper (e.g., "Robotiq85")
+ camera_params: Camera configuration parameters
+ camera_height: Height of camera images in pixels
+ camera_width: Width of camera images in pixels
+ render: Whether to render the simulation visually
+ n_steps_short: Number of simulation steps for quick movements
+ n_steps_long: Number of simulation steps for initial/slow movements
+ debug_cameras: Additional camera names for debugging views
+ square: Whether to crop images to square aspect ratio
+ """
+ # Store configuration parameters
+ self.robot_name = robot_name
+ self.gripper_name = gripper_name
+ self.camera_params = camera_params
+ self.render = render
+ self.n_steps_long = n_steps_long
+ self.n_steps_short= n_steps_short
+ self.num_frames = 2 # Number of frames to keep in observation history
+ self.camera_height = camera_height
+ self.camera_width = camera_width
+ self.camera_name = "frontview" # Main camera name for single-arm setup
+ self.square = square
+ self.debug_cameras = list(debug_cameras) if debug_cameras else []
+
+ # Configure observation specifications for robomimic
+ obs_spec = dict(
+ obs=dict(
+ low_dim=["robot0_eef_pos"], # End-effector position observations
+ rgb=[f"{self.camera_params.name}_image"] + [f"{cam}_image" for cam in self.debug_cameras],
+ ),
+ )
+ ObsUtils.initialize_obs_utils_with_obs_specs(
+ obs_modality_specs=obs_spec)
+
+ # Configure robosuite environment options
+ options: dict[str, Union[str, list[str], dict[str, Any], bool, int, np.ndarray]] = {}
+ options["env_name"] = "Phantom" # Single-arm environment
+ options["robots"] = [self.robot_name] # Single robot
+ options["gripper_types"] = [f"{self.gripper_name}Gripper"] # Single gripper
+
+ # Configure OSC pose controller
+ controller_config = load_controller_config(default_controller="OSC_POSE")
+ controller_config["control_delta"] = False # Use absolute positioning
+ controller_config["uncouple_pos_ori"] = False # Couple position and orientation
+ options["controller_configs"] = controller_config
+
+ # Camera and observation settings
+ options["camera_heights"] = self.camera_height
+ options["camera_widths"] = self.camera_width
+ options["camera_segmentations"] = "instance" # Instance segmentation masks
+ options["direct_gripper_control"] = True
+ options["use_depth_obs"] = True
+
+ # Set camera parameters
+ options["camera_pos"] = self.camera_params.pos
+ options["camera_quat_wxyz"] = self.camera_params.ori_wxyz
+ options["camera_sensorsize"] = self.camera_params.sensorsize
+ options["camera_principalpixel"] = self.camera_params.principalpixel
+ options["camera_focalpixel"] = self.camera_params.focalpixel
+
+ # Create the robosuite environment
+ self.env = EnvRobosuite(
+ **options,
+ render=render,
+ render_offscreen=True, # Enable offscreen rendering for image capture
+ use_image_obs=True,
+ camera_names=[self.camera_params.name] + self.debug_cameras,
+ control_freq=20, # 20 Hz control frequency
+ )
+
+ # Initialize environment and set robot base position
+ self.reset()
+ self.robot_base_pos = self.DEFAULT_ROBOT_BASE_POS # Fixed base position for single-arm setup
+
+ def reset(self):
+ """Reset environment and clear observation history."""
+ self.env.reset()
+ self.obs_history = deque()
+
+ def close(self):
+ """Close the simulation environment."""
+ self.env.env.close()
+
+ def get_action_from_ee_pose(self, ee_pos: np.ndarray, ee_quat_xyzw: np.ndarray, gripper_action: float,
+ use_base_offset: bool = False) -> np.ndarray:
+ """
+ Convert end-effector pose to robot action vector.
+
+ This method transforms the desired end-effector position and orientation
+ into the action format expected by the robot controller.
+
+ Args:
+ ee_pos: End-effector position as 3D array
+ ee_quat_xyzw: End-effector orientation as quaternion (x, y, z, w)
+ gripper_action: Gripper action value
+ use_base_offset: Whether to add robot base offset to position
+
+ Returns:
+ Action vector [position(3), rotation(3), gripper(1)]
+ """
+ # Handle batch inputs by taking the last element
+ if ee_pos.ndim > 1:
+ ee_pos = ee_pos[-1]
+ ee_quat_xyzw = ee_quat_xyzw[-1]
+
+ # Add base offset if requested
+ if use_base_offset:
+ ee_pos = ee_pos + self.robot_base_pos
+
+ # Apply -135 degree Z rotation for single-arm setup coordinate conversion
+ rot = Rotation.from_quat(ee_quat_xyzw)
+ rot_135deg = Rotation.from_euler('z', -135, degrees=True)
+ new_rot = rot * rot_135deg
+
+ # Convert rotation to axis-angle representation
+ # Note: commented lines show alternative approach using quaternion directly
+ # quat_rotated = rot_rotated135.as_quat()
+ # axis_angle = Rotation.from_quat(quat_rotated).as_rotvec()
+ axis_angle = new_rot.as_rotvec()
+
+ # Combine position, rotation, and gripper action into action vector
+ action = np.concatenate([ee_pos, axis_angle, [gripper_action]])
+
+ return action
+
+ def _get_initial_obs_history(self, state: dict) -> deque:
+ """
+ Initialize observation history by repeating the first observation.
+
+ This creates a history buffer filled with the initial robot state,
+ which is useful for algorithms that require temporal context.
+
+ Args:
+ state: Initial robot state dictionary
+
+ Returns:
+ Deque containing repeated initial observations
+ """
+ obs_history = deque(
+ [self.move_to_target_state(state, init=True)],
+ maxlen=self.num_frames,
+ )
+ # Fill remaining slots with copies of the initial observation
+ for _ in range(self.num_frames-1):
+ obs_history.append(self.move_to_target_state(state))
+ return obs_history
+
+ def get_obs_history(self, state: dict) -> list:
+ """
+ Get observation history with specified length.
+
+ Maintains a rolling buffer of recent observations for temporal context.
+
+ Args:
+ state: Current robot state dictionary
+
+ Returns:
+ List of recent observations (length = self.num_frames)
+ """
+ if len(self.obs_history) == 0:
+ # Initialize history if empty
+ self.obs_history = self._get_initial_obs_history(state)
+ else:
+ # Add new observation to history
+ self.obs_history.append(self.move_to_target_state(state))
+ return list(self.obs_history)
+
+ def move_to_target_state(self, state: dict, init=False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Move robot to target state and collect observation data.
+
+ Args:
+ state: Target state containing position, orientation, and gripper state
+ init: Whether this is an initialization step (uses longer movement time)
+
+ Returns:
+ Dictionary containing observation data:
+ - robot_mask: Binary mask showing robot pixels
+ - gripper_mask: Binary mask showing gripper pixels
+ - rgb_img: RGB camera image
+ - depth_img: Depth camera image
+ - robot_pos: Robot end-effector position relative to base
+ - pos_err: Position tracking error magnitude
+ - {cam}_img: Additional camera images if debug_cameras specified
+ """
+ # Convert gripper position to robot action
+ gripper_action = self._convert_handgripper_pos_to_action(state["gripper_pos"])
+
+ # Choose movement duration based on whether this is initialization
+ n_steps = self.n_steps_long if init else self.n_steps_short
+
+ # Execute movement to target pose
+ obs = self.move_to_pose(state["pos"], state["ori_xyzw"], float(gripper_action), n_steps)
+
+ # Extract observation data from simulation
+ robot_mask = np.squeeze(self.get_robot_mask(obs))
+ gripper_mask = np.squeeze(self.get_gripper_mask(obs))
+ rgb_img = self.get_image(obs)
+ depth_img = self.get_depth_image(obs)
+ robot_pos = obs["robot0_eef_pos"] - self.robot_base_pos
+ pos_error = np.linalg.norm(robot_pos - state["pos"])
+
+ # Compile output dictionary
+ output = {
+ "robot_mask": robot_mask,
+ "gripper_mask": gripper_mask,
+ "rgb_img": rgb_img,
+ "depth_img": depth_img,
+ "robot_pos": robot_pos,
+ "pos_err": pos_error,
+ }
+
+ # Add debug camera images if specified
+ for cam in self.debug_cameras:
+ cam_img = self.get_cam_image(obs, cam)
+ output[f"{cam}_img"] = cam_img
+
+ return output
+
+ def _convert_handgripper_pos_to_action(self, gripper_pos: float) -> np.ndarray:
+ """
+ Convert hand gripper position to robot gripper action.
+
+ Maps from physical gripper opening distance to robot action values.
+ Different gripper types may have different mappings.
+
+ Args:
+ gripper_pos: Gripper opening distance in meters
+
+ Returns:
+ Robot gripper action value (0-255 for Robotiq85)
+
+ Raises:
+ ValueError: If gripper type is not supported
+ """
+ if self.gripper_name == "Robotiq85":
+ # Robotiq85 gripper specifications
+ min_gripper_pos, max_gripper_pos = 0.0, 0.085 # 0 to 8.5cm opening
+ gripper_pos = np.clip(gripper_pos, min_gripper_pos, max_gripper_pos)
+ open_gripper_action, closed_gripper_action = 0, 255 # 0=open, 255=closed
+ # Linear interpolation between open and closed states
+ return np.interp(gripper_pos, [min_gripper_pos, max_gripper_pos], [closed_gripper_action, open_gripper_action])
+ else:
+ raise ValueError(f"Gripper name {self.gripper_name} not supported")
+
+ def move_to_pose(self, ee_pos: np.ndarray, ee_ori: np.ndarray, gripper_action: float, n_steps: int) -> dict:
+ """
+ Execute robot movement to target pose.
+
+ Sends action commands to the simulation for the specified number of steps.
+
+ Args:
+ ee_pos: End-effector position as 3D array
+ ee_ori: End-effector orientation as quaternion (x, y, z, w)
+ gripper_action: Gripper action value
+ n_steps: Number of simulation steps to execute
+
+ Returns:
+ Final observation dictionary from simulation
+ """
+ # Convert pose to action vector
+ action = self.get_action_from_ee_pose(ee_pos, ee_ori, gripper_action, use_base_offset=True)
+
+ # Execute action for specified number of steps
+ for _ in range(n_steps):
+ obs, _, _, _ = self.env.step(action)
+ if self.render:
+ self.env.render()
+ return obs
+
+ def get_image(self, obs: dict) -> np.ndarray:
+ """
+ Extract RGB image from observation.
+
+ Handles image format conversion and optional square cropping.
+
+ Args:
+ obs: Observation dictionary containing image data
+
+ Returns:
+ RGB image as numpy array (H, W, 3)
+ """
+ img = obs[f"{self.camera_name}_image"]
+ img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ return img
+
+ def get_cam_image(self, obs: dict, camera_name: str) -> np.ndarray:
+ """
+ Extract RGB image from specific camera.
+
+ Args:
+ obs: Observation dictionary containing image data
+ camera_name: Name of the camera to extract image from
+
+ Returns:
+ RGB image as numpy array (H, W, 3)
+ """
+ img = obs[f"{camera_name}_image"]
+ img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ return img
+
+ def get_seg_image(self, obs: dict) -> np.ndarray:
+ """
+ Extract instance segmentation image.
+
+ Args:
+ obs: Observation dictionary containing segmentation data
+
+ Returns:
+ Segmentation image as uint8 array where each pixel value
+ represents a different object instance ID
+ """
+ img = obs["frontview_segmentation_instance"] # Fixed camera name for single-arm
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ img = img.astype(np.uint8)
+ return img
+
+ def get_depth_image(self, obs: dict) -> np.ndarray:
+ """
+ Extract and process depth image.
+
+ Converts raw depth buffer to real-world depth values using
+ robosuite's depth processing utilities.
+
+ Args:
+ obs: Observation dictionary containing depth data
+
+ Returns:
+ Depth image as numpy array where values represent
+ distance in meters
+ """
+ img = obs["frontview_depth"] # Fixed camera name for single-arm
+ img = get_real_depth_map(sim=self.env.env.sim, depth_map=img)
+ height = img.shape[0]
+ width = img.shape[1]
+
+ # Crop to square if requested
+ if self.square:
+ n_remove = int((width - height)/2)
+ img = img[:,n_remove:-n_remove,:]
+ return img
+
+ def get_robot_mask(self, obs: dict) -> np.ndarray:
+ """
+ Generate binary mask for robot pixels.
+
+ Uses instance segmentation to identify which pixels belong to
+ the robot arm (instance ID 1).
+
+ Args:
+ obs: Observation dictionary containing segmentation data
+
+ Returns:
+ Binary mask where 1 indicates robot pixels, 0 otherwise
+ """
+ seg_img = self.get_seg_image(obs)
+ mask = np.zeros_like(seg_img)
+ mask[seg_img == 1] = 1 # Robot arm
+ return mask
+
+ def get_gripper_mask(self, obs: dict) -> np.ndarray:
+ """
+ Generate binary mask for gripper pixels.
+
+ Uses instance segmentation to identify which pixels belong to
+ the robot gripper (instance ID 3).
+
+ Args:
+ obs: Observation dictionary containing segmentation data
+
+ Returns:
+ Binary mask where 1 indicates gripper pixels, 0 otherwise
+ """
+ seg_img = self.get_seg_image(obs)
+ mask = np.zeros_like(seg_img)
+ mask[seg_img == 3] = 1 # Gripper
+ return mask
\ No newline at end of file
diff --git a/phantom/phantom/utils/__init__.py b/phantom/phantom/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/phantom/utils/bbox_utils.py b/phantom/phantom/utils/bbox_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b893cda53bb99e67279e70a08ef4557cd79e9fb4
--- /dev/null
+++ b/phantom/phantom/utils/bbox_utils.py
@@ -0,0 +1,38 @@
+import numpy as np
+import numpy.typing as npt
+
+def get_bbox_center(bbox: np.ndarray) -> np.ndarray:
+ """Calculate center point of bounding box."""
+ return np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
+
+
+def get_bbox_area(bbox: np.ndarray) -> float:
+ """Get the area of a bounding box."""
+ return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+
+def get_overlap_score(bbox1: np.ndarray, bbox2: np.ndarray) -> float:
+ """ Get the overlap area between two boxes divided by the area of the smaller box """
+ area1 = get_bbox_area(bbox1)
+ area2 = get_bbox_area(bbox2)
+ overlap_area = get_overlap_area(bbox1, bbox2)
+ return overlap_area / min(area1, area2)
+
+def get_overlap_area(bbox1: np.ndarray, bbox2: np.ndarray) -> float:
+ """ Get the overlap area between two boxes """
+ return max(0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0])) * max(0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]))
+
+def get_bbox_center_min_dist_to_edge(bboxes: npt.NDArray[np.float32], W: int, H: int) -> npt.NDArray[np.float32]:
+ """
+ Get the minimum distance of the bbox center to the edge of the image.
+ """
+ center_min_dist_to_edge_list = []
+ for bbox in bboxes:
+ x1, y1, x2, y2 = bbox
+ center = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
+ min_dist_to_edge = min(center[0], center[1], W - center[0], H - center[1])
+ center_min_dist_to_edge_list.append(min_dist_to_edge)
+ return np.array(center_min_dist_to_edge_list)
+
+
+
diff --git a/phantom/phantom/utils/data_utils.py b/phantom/phantom/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..292dbef34ab5ef370de009f5f6f57a6d3ffb31de
--- /dev/null
+++ b/phantom/phantom/utils/data_utils.py
@@ -0,0 +1,38 @@
+import re
+import os
+import numpy as np
+import pandas as pd
+from pathlib import Path
+
+def get_finger_poses_from_pkl(path: Path) -> dict:
+ """Get human finger poses from pkl file."""
+ finger_poses = pd.read_pickle(path)
+ thumb_poses = np.vstack(finger_poses["thumb"])
+ index_poses = np.vstack(finger_poses["index"])
+ hand_ee_poses = np.vstack(finger_poses["hand_ee"])
+ skeleton_poses = np.stack(finger_poses["skeleton"], axis=0)
+ hand_poses = np.stack(finger_poses["hand_pose"], axis=0)
+ all_global_orient = np.vstack(finger_poses["global_orient"])
+ data = {
+ "thumb": thumb_poses,
+ "index": index_poses,
+ "hand_ee": hand_ee_poses,
+ "skeleton": skeleton_poses,
+ "hand_pose": hand_poses,
+ "global_orient": all_global_orient
+ }
+ return data
+
+def get_parent_folder_of_package(package_name: str) -> str:
+ # Import the package
+ package = __import__(package_name)
+
+ # Get the absolute path of the imported package
+ package_path = package.__file__
+ if package_path is None:
+ raise ValueError(f"Package {package_name} does not have a valid __file__ attribute")
+ package_path = os.path.abspath(package_path)
+
+ # Get the parent directory of the package directory
+ return os.path.dirname(os.path.dirname(package_path))
+
diff --git a/phantom/phantom/utils/image_utils.py b/phantom/phantom/utils/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a61d1d9d2f03d333dd6dc86885c7ce3510eec83
--- /dev/null
+++ b/phantom/phantom/utils/image_utils.py
@@ -0,0 +1,103 @@
+import json
+import numpy as np
+import cv2
+import os
+import mediapy as media
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+@dataclass
+class BoundingBox:
+ xmin: int
+ ymin: int
+ xmax: int
+ ymax: int
+
+ @property
+ def xyxy(self) -> List[float]:
+ return [self.xmin, self.ymin, self.xmax, self.ymax]
+
+
+@dataclass
+class DetectionResult:
+ score: float
+ label: str
+ box: BoundingBox
+ mask: Optional[np.ndarray] = None
+
+ @classmethod
+ def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
+ return cls(
+ score=detection_dict["score"],
+ label=detection_dict["label"],
+ box=BoundingBox(
+ xmin=detection_dict["box"]["xmin"],
+ ymin=detection_dict["box"]["ymin"],
+ xmax=detection_dict["box"]["xmax"],
+ ymax=detection_dict["box"]["ymax"],
+ ),
+ )
+
+def get_transformation_matrix_from_extrinsics(camera_extrinsics: List[Dict]) -> np.ndarray:
+ """Get homogeneous transformation matrix from camera extrinsics."""
+ cam_base_pos = np.array(camera_extrinsics[0]["camera_base_pos"])
+ cam_base_ori = np.array(camera_extrinsics[0]["camera_base_ori"])
+ T_cam2robot = np.eye(4)
+ T_cam2robot[:3, 3] = cam_base_pos
+ T_cam2robot[:3, :3] = np.array(cam_base_ori).reshape(3, 3)
+ return T_cam2robot
+
+
+def get_intrinsics_from_json(json_path: str) -> Tuple[np.ndarray, dict]:
+ with open(json_path, "r") as f:
+ camera_intrinsics = json.load(f)
+
+ # Get camera matrix
+ fx = camera_intrinsics["left"]["fx"]
+ fy = camera_intrinsics["left"]["fy"]
+ cx = camera_intrinsics["left"]["cx"]
+ cy = camera_intrinsics["left"]["cy"]
+ v_fov = camera_intrinsics["left"]["v_fov"]
+ intrinsics_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+
+ intrinsics_dict = {
+ "fx": fx,
+ "fy": fy,
+ "cx": cx,
+ "cy": cy,
+ "v_fov": v_fov,
+ }
+
+ return intrinsics_matrix, intrinsics_dict
+
+def resize_binary_image(image: np.ndarray, new_size: int) -> np.ndarray:
+ max_value = np.max(image)
+
+ # Resize the image
+ resized_image = cv2.resize(image, (new_size, new_size), interpolation=cv2.INTER_NEAREST)
+
+ if max_value == 1:
+ _, binary_image = cv2.threshold(resized_image, 0.5, 1, cv2.THRESH_BINARY)
+ else:
+ _, binary_image = cv2.threshold(resized_image, 127, 255, cv2.THRESH_BINARY)
+
+ return binary_image
+
+
+def convert_video_to_images(video_path: str, save_folder: str, square=False, reverse=False):
+ """Save each frame of video as an image in save_folder."""
+ if not os.path.exists(save_folder):
+ os.makedirs(save_folder)
+
+ imgs = np.array(media.read_video(str(video_path)))
+ n_imgs = len(imgs)
+ if reverse:
+ imgs = imgs[::-1]
+ for idx in range(n_imgs):
+ img = imgs[idx]
+ if square:
+ delta = (img.shape[1] - img.shape[0]) // 2
+ img = img[:, delta:-delta, :]
+ media.write_image(f"{save_folder}/{idx:05d}.jpg", img)
+
+
diff --git a/phantom/phantom/utils/pcd_utils.py b/phantom/phantom/utils/pcd_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0bb7677653781ce386cd14f4e0384a6d514a952
--- /dev/null
+++ b/phantom/phantom/utils/pcd_utils.py
@@ -0,0 +1,210 @@
+import numpy as np
+from typing import Tuple, Optional
+import open3d as o3d # type: ignore
+import trimesh
+from sklearn.neighbors import NearestNeighbors # type: ignore
+
+def preprocess_point_cloud(pcd: o3d.geometry.PointCloud,
+ voxel_size: float) -> Tuple[o3d.geometry.PointCloud, o3d.pipelines.registration.Feature]:
+ """
+ Downsample point cloud to desired voxel resolution and compute FPFH features.
+ """
+ pcd_down = pcd.voxel_down_sample(voxel_size)
+ radius_normal = voxel_size * 2
+ pcd_down.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
+ radius_feature = voxel_size * 5
+ pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
+ pcd_down, o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100))
+ return pcd_down, pcd_fpfh
+
+
+def global_registration(source_pcd: o3d.geometry.PointCloud, target_pcd: o3d.geometry.PointCloud,
+ voxel_size: float) -> o3d.pipelines.registration.RegistrationResult:
+ """
+ Register two point clouds using global registration with RANSAC.
+ """
+ source_down, source_fpfh = preprocess_point_cloud(source_pcd, voxel_size)
+ target_down, target_fpfh = preprocess_point_cloud(target_pcd, voxel_size)
+
+ distance_threshold = voxel_size * 1.5
+ result_ransac = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
+ source_down, target_down, source_fpfh, target_fpfh, True,
+ distance_threshold,
+ o3d.pipelines.registration.TransformationEstimationPointToPoint(),
+ 4, # RANSAC iterations
+ [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
+ o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)],
+ o3d.pipelines.registration.RANSACConvergenceCriteria(4000000, 500))
+
+ return result_ransac
+
+
+def icp_registration(source_pcd: o3d.geometry.PointCloud, target_pcd: o3d.geometry.PointCloud,
+ voxel_size: float=0.05, use_global_registration:bool=True,
+ init_transform:Optional[np.ndarray]=None) -> Tuple[o3d.geometry.PointCloud, np.ndarray]:
+ """
+ Register two point clouds using ICP algorithm.
+ """
+ # Optional global registration using RANSAC
+ if use_global_registration:
+ if init_transform is None:
+ result_ransac = global_registration(source_pcd, target_pcd, voxel_size)
+ init_transform = result_ransac.transformation
+ else:
+ init_transform = np.eye(4)
+
+ # Refine alignment using ICP
+ max_correspondence_distance = voxel_size * 5
+ result_icp = o3d.pipelines.registration.registration_icp(
+ source=source_pcd, target=target_pcd, max_correspondence_distance=max_correspondence_distance,
+ init=init_transform,
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint())
+
+ if np.array_equal(init_transform, result_icp.transformation):
+ result_ransac = global_registration(source_pcd, target_pcd, voxel_size)
+ init_transform = result_ransac.transformation
+ result_icp = o3d.pipelines.registration.registration_icp(
+ source=source_pcd, target=target_pcd, max_correspondence_distance=max_correspondence_distance,
+ init=init_transform,
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint())
+
+ aligned_source_pcd = source_pcd.transform(result_icp.transformation)
+
+ return aligned_source_pcd, result_icp.transformation
+
+
+def get_visible_points(mesh, origin: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Return list of points in mesh that are visible from origin.
+ """
+ intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)
+ pts = mesh.vertices
+ vectors = pts - origin
+ directions = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
+ visible_triangle_indices = intersector.intersects_first(np.tile(origin, (pts.shape[0], 1)), directions)
+ visible_triangles = mesh.faces[visible_triangle_indices]
+ visible_vertex_indices = np.unique(visible_triangles)
+ visible_points = pts[visible_vertex_indices]
+ return np.array(visible_points).astype(np.float32), np.array(visible_vertex_indices)
+
+
+def get_pcd_from_points(points: np.ndarray, colors: Optional[np.ndarray]=None) -> o3d.geometry.PointCloud:
+ """
+ Convert a list of points to an Open3D point cloud.
+ """
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ if colors is not None:
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+ pcd.remove_non_finite_points()
+ return pcd
+
+
+def visualize_pcds(list_pcds: list, visible: bool=True) -> np.ndarray:
+ """
+ Visualize a list of point clouds.
+ """
+ visualization_image = None
+ vis = o3d.visualization.Visualizer()
+ vis.create_window(visible=visible)
+ opt = vis.get_render_option()
+ opt.background_color = np.asarray([0.2, 0.2, 0.2])
+ for pcd in list_pcds:
+ if pcd is not None:
+ vis.add_geometry(pcd)
+ vis.poll_events()
+ vis.update_renderer()
+ if not visible:
+ visualization_image = vis.capture_screen_float_buffer(do_render=True)
+ visualization_image = (255.0 * np.asarray(visualization_image)).astype(np.uint8)
+ if visible:
+ vis.run()
+ vis.destroy_window()
+ if visualization_image is None:
+ visualization_image = np.array([])
+ return visualization_image
+
+def radius_outlier_detection(points: np.ndarray, radius: float=5,
+ min_neighbors: int=5) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Detect outliers in a point cloud using radius-based outlier detection.
+ """
+ # Fit the NearestNeighbors model
+ nbrs = NearestNeighbors(radius=radius).fit(points)
+
+ # Get the number of neighbors for each point within the specified radius
+ distances, indices = nbrs.radius_neighbors(points)
+
+ # Detect points with fewer neighbors than the minimum threshold
+ outliers_mask = np.array([len(neigh) < min_neighbors for neigh in indices])
+
+ outlier_pts = points[outliers_mask]
+
+ return outliers_mask, outlier_pts
+
+
+def remove_outliers(pcd: o3d.geometry.PointCloud, radius: float=5,
+ min_neighbors: int=5) -> Tuple[o3d.geometry.PointCloud, np.ndarray]:
+ """
+ Remove outliers from a point cloud using radius-based outlier detection.
+ """
+ outlier_indices, outlier_pts = radius_outlier_detection(np.asarray(pcd.points),
+ radius=radius, min_neighbors=min_neighbors)
+ filtered_pts = np.asarray(pcd.points)[~outlier_indices]
+ filtered_colors = np.asarray(pcd.colors)[~outlier_indices]
+ filtered_pcd = get_pcd_from_points(filtered_pts, colors=filtered_colors)
+ return filtered_pcd, outlier_indices
+
+def get_3D_points_from_pixels(pixels_2d: np.ndarray, depth_map: np.ndarray, intrinsics: dict) -> np.ndarray:
+ """
+ Convert an array of pixel coordinates and depth map to 3D points.
+ """
+ px = pixels_2d[:, 0]
+ py = pixels_2d[:, 1]
+
+ x = (px - intrinsics["cx"]) / intrinsics["fx"]
+ y = (py - intrinsics["cy"]) / intrinsics["fy"]
+
+ if len(depth_map.shape) == 3:
+ depth_map = depth_map[:, :, 0]
+
+ depth = depth_map[py, px]
+
+ X = x * depth
+ Y = y * depth
+
+ points_3d = np.stack((X, Y, depth), axis=1)
+ return points_3d
+
+def get_point_cloud_of_segmask(mask: np.ndarray, depth_img: np.ndarray, img: np.ndarray,
+ intrinsics: dict, visualize: bool=False) -> o3d.geometry.PointCloud:
+ """
+ Return the point cloud that corresponds to the segmentation mask in the depth image.
+ """
+ idxs_y, idxs_x = mask.nonzero()
+ pixels_2d = np.stack((idxs_x, idxs_y), axis=1)
+ seg_points = get_3D_points_from_pixels(pixels_2d, depth_img, intrinsics)
+ seg_colors = img[idxs_y, idxs_x, :] / 255.0 # Normalize to [0,1] for cv2
+
+ pcd = get_pcd_from_points(seg_points, colors=seg_colors)
+
+ if visualize:
+ visualize_pcds([pcd])
+
+ return pcd
+
+def get_bbox_of_3d_points(points: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Return the bounding box of 3D points.
+ """
+ min_xyz = np.min(points, axis=0)
+ max_xyz = np.max(points, axis=0)
+ return min_xyz, max_xyz
+
+def trim_pcd_to_bbox(pcd: o3d.geometry.PointCloud, bbox: Tuple[np.ndarray, np.ndarray]) -> o3d.geometry.PointCloud:
+ """
+ Trim a point cloud to the specified bounding box.
+ """
+ min_xyz, max_xyz = bbox
+ trimmed_pcd = pcd.crop(o3d.geometry.AxisAlignedBoundingBox(min_xyz, max_xyz))
+ return trimmed_pcd
\ No newline at end of file
diff --git a/phantom/phantom/utils/transform_utils.py b/phantom/phantom/utils/transform_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbca6ba296d99e50f9a0484d6205a7f60396d582
--- /dev/null
+++ b/phantom/phantom/utils/transform_utils.py
@@ -0,0 +1,43 @@
+import numpy as np
+import math
+
+EPS = np.finfo(float).eps * 4.0
+
+def transform_pts(pts: np.ndarray, T: np.ndarray) -> np.ndarray:
+ pts = np.hstack([pts, np.ones((len(pts), 1))])
+ pts = np.dot(T, pts.T).T
+ return pts[:, :3]
+
+def project_point_to_plane(point: np.ndarray, plane_coeffs: np.ndarray) -> np.ndarray:
+ """
+ Projects a 3D point onto a plane defined by its coefficients.
+
+ Args:
+ point (array-like): Coordinates of the point to be projected (x0, y0, z0).
+ plane_coeffs (array-like): Coefficients of the plane (a, b, c, d) for ax + by + cz + d = 0.
+
+ Returns:
+ numpy.ndarray: The projected point's coordinates on the plane.
+ """
+ # Convert inputs to numpy arrays
+ point = np.array(point)
+ plane_coeffs = np.array(plane_coeffs)
+
+ # Extract the plane normal vector and constant term
+ normal = plane_coeffs[:3] # [a, b, c]
+ d = plane_coeffs[3]
+
+ # Normalize the plane normal vector
+ normal_magnitude = np.linalg.norm(normal)
+ if normal_magnitude == 0:
+ raise ValueError("Invalid plane coefficients: normal vector cannot have zero magnitude.")
+ normal /= normal_magnitude
+
+ # Calculate the signed distance from the point to the plane
+ distance = np.dot(normal, point) + d / normal_magnitude
+
+ # Project the point onto the plane
+ projected_point = point - distance * normal
+
+ return projected_point
+
diff --git a/phantom/setup.py b/phantom/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a9fec23ac0eb103e3f849c66355f5bce2f995f4
--- /dev/null
+++ b/phantom/setup.py
@@ -0,0 +1,7 @@
+import setuptools
+
+setuptools.setup(
+ name="phantom",
+ version="0.1",
+ packages=setuptools.find_packages(exclude=["submodules", "submodules.*"]),
+)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-E2FGVI/.gitignore b/phantom/submodules/phantom-E2FGVI/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..de2ae7cce9460e73fbd78398b0e401a3bc7b861f
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/.gitignore
@@ -0,0 +1,136 @@
+# Customized
+*.pth
+*.pt
+keys.txt
+results/
+.vscode/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/__init__.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi.json b/phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi.json
new file mode 100644
index 0000000000000000000000000000000000000000..2093a0deb42da5dd2c8e60f63cfb458ccc6852c2
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi.json
@@ -0,0 +1,41 @@
+{
+ "seed": 2021,
+ "save_dir": "release_model/",
+ "train_data_loader": {
+ "name": "youtube-vos",
+ "data_root": "datasets",
+ "w": 432,
+ "h": 240,
+ "num_local_frames": 5,
+ "num_ref_frames": 3
+ },
+ "losses": {
+ "hole_weight": 1,
+ "valid_weight": 1,
+ "flow_weight": 1,
+ "adversarial_weight": 0.01,
+ "GAN_LOSS": "hinge"
+ },
+ "model": {
+ "net": "e2fgvi",
+ "no_dis": 0
+ },
+ "trainer": {
+ "type": "Adam",
+ "beta1": 0,
+ "beta2": 0.99,
+ "lr": 1e-4,
+ "batch_size": 8,
+ "num_workers": 2,
+ "log_freq": 100,
+ "save_freq": 5e3,
+ "iterations": 50e4,
+ "scheduler": {
+ "type": "MultiStepLR",
+ "milestones": [
+ 40e4
+ ],
+ "gamma": 0.1
+ }
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi_hq.json b/phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi_hq.json
new file mode 100644
index 0000000000000000000000000000000000000000..6693b731cc62e354e2c27342d9e5a2807e0c0a4a
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi_hq.json
@@ -0,0 +1,41 @@
+{
+ "seed": 2021,
+ "save_dir": "release_model/",
+ "train_data_loader": {
+ "name": "youtube-vos",
+ "data_root": "datasets",
+ "w": 432,
+ "h": 240,
+ "num_local_frames": 5,
+ "num_ref_frames": 3
+ },
+ "losses": {
+ "hole_weight": 1,
+ "valid_weight": 1,
+ "flow_weight": 1,
+ "adversarial_weight": 0.01,
+ "GAN_LOSS": "hinge"
+ },
+ "model": {
+ "net": "e2fgvi_hq",
+ "no_dis": 0
+ },
+ "trainer": {
+ "type": "Adam",
+ "beta1": 0,
+ "beta2": 0.99,
+ "lr": 1e-4,
+ "batch_size": 8,
+ "num_workers": 2,
+ "log_freq": 100,
+ "save_freq": 5e3,
+ "iterations": 50e4,
+ "scheduler": {
+ "type": "MultiStepLR",
+ "milestones": [
+ 40e4
+ ],
+ "gamma": 0.1
+ }
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/__init__.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/dataset.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d7f992c73cc3b32be64e335caf81f236cb0242
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/dataset.py
@@ -0,0 +1,135 @@
+import os
+import json
+import random
+
+import cv2
+from PIL import Image
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+
+from core.utils import (TrainZipReader, TestZipReader,
+ create_random_shape_with_random_motion, Stack,
+ ToTorchFormatTensor, GroupRandomHorizontalFlip)
+
+
+class TrainDataset(torch.utils.data.Dataset):
+ def __init__(self, args: dict, debug=False):
+ self.args = args
+ self.num_local_frames = args['num_local_frames']
+ self.num_ref_frames = args['num_ref_frames']
+ self.size = self.w, self.h = (args['w'], args['h'])
+
+ json_path = os.path.join(args['data_root'], args['name'], 'train.json')
+ with open(json_path, 'r') as f:
+ self.video_dict = json.load(f)
+ self.video_names = list(self.video_dict.keys())
+ if debug:
+ self.video_names = self.video_names[:100]
+
+ self._to_tensors = transforms.Compose([
+ Stack(),
+ ToTorchFormatTensor(),
+ ])
+
+ def __len__(self):
+ return len(self.video_names)
+
+ def __getitem__(self, index):
+ item = self.load_item(index)
+ return item
+
+ def _sample_index(self, length, sample_length, num_ref_frame=3):
+ complete_idx_set = list(range(length))
+ pivot = random.randint(0, length - sample_length)
+ local_idx = complete_idx_set[pivot:pivot + sample_length]
+ remain_idx = list(set(complete_idx_set) - set(local_idx))
+ ref_index = sorted(random.sample(remain_idx, num_ref_frame))
+
+ return local_idx + ref_index
+
+ def load_item(self, index):
+ video_name = self.video_names[index]
+ # create masks
+ all_masks = create_random_shape_with_random_motion(
+ self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
+
+ # create sample index
+ selected_index = self._sample_index(self.video_dict[video_name],
+ self.num_local_frames,
+ self.num_ref_frames)
+
+ # read video frames
+ frames = []
+ masks = []
+ for idx in selected_index:
+ video_path = os.path.join(self.args['data_root'],
+ self.args['name'], 'JPEGImages',
+ f'{video_name}.zip')
+ img = TrainZipReader.imread(video_path, idx).convert('RGB')
+ img = img.resize(self.size)
+ frames.append(img)
+ masks.append(all_masks[idx])
+
+ # normalizate, to tensors
+ frames = GroupRandomHorizontalFlip()(frames)
+ frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
+ mask_tensors = self._to_tensors(masks)
+ return frame_tensors, mask_tensors, video_name
+
+
+class TestDataset(torch.utils.data.Dataset):
+ def __init__(self, args):
+ self.args = args
+ self.size = self.w, self.h = args.size
+
+ with open(os.path.join(args.data_root, args.dataset, 'test.json'),
+ 'r') as f:
+ self.video_dict = json.load(f)
+ self.video_names = list(self.video_dict.keys())
+
+ self._to_tensors = transforms.Compose([
+ Stack(),
+ ToTorchFormatTensor(),
+ ])
+
+ def __len__(self):
+ return len(self.video_names)
+
+ def __getitem__(self, index):
+ item = self.load_item(index)
+ return item
+
+ def load_item(self, index):
+ video_name = self.video_names[index]
+ ref_index = list(range(self.video_dict[video_name]))
+
+ # read video frames
+ frames = []
+ masks = []
+ for idx in ref_index:
+ video_path = os.path.join(self.args.data_root, self.args.dataset,
+ 'JPEGImages', f'{video_name}.zip')
+ img = TestZipReader.imread(video_path, idx).convert('RGB')
+ img = img.resize(self.size)
+ frames.append(img)
+ mask_path = os.path.join(self.args.data_root, self.args.dataset,
+ 'test_masks', video_name,
+ str(idx).zfill(5) + '.png')
+ mask = Image.open(mask_path).resize(self.size,
+ Image.NEAREST).convert('L')
+ # origin: 0 indicates missing. now: 1 indicates missing
+ mask = np.asarray(mask)
+ m = np.array(mask > 0).astype(np.uint8)
+ m = cv2.dilate(m,
+ cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
+ iterations=4)
+ mask = Image.fromarray(m * 255)
+ masks.append(mask)
+
+ # to tensors
+ frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
+ frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
+ mask_tensors = self._to_tensors(masks)
+ return frame_tensors, mask_tensors, video_name, frames_PIL
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/dist.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4e9e670a3b853fac345618d3557d648d813902
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/dist.py
@@ -0,0 +1,47 @@
+import os
+import torch
+
+
+def get_world_size():
+ """Find OMPI world size without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('PMI_SIZE') is not None:
+ return int(os.environ.get('PMI_SIZE') or 1)
+ elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
+ else:
+ return torch.cuda.device_count()
+
+
+def get_global_rank():
+ """Find OMPI world rank without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('PMI_RANK') is not None:
+ return int(os.environ.get('PMI_RANK') or 0)
+ elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
+ else:
+ return 0
+
+
+def get_local_rank():
+ """Find OMPI local rank without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('MPI_LOCALRANKID') is not None:
+ return int(os.environ.get('MPI_LOCALRANKID') or 0)
+ elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
+ else:
+ return 0
+
+
+def get_master_ip():
+ if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
+ return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
+ elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
+ return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
+ else:
+ return "127.0.0.1"
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/loss.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d0d5f4e3118d82a844921a99b5aa66f05bb7d6
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/loss.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+
+
+class AdversarialLoss(nn.Module):
+ r"""
+ Adversarial loss
+ https://arxiv.org/abs/1711.10337
+ """
+ def __init__(self,
+ type='nsgan',
+ target_real_label=1.0,
+ target_fake_label=0.0):
+ r"""
+ type = nsgan | lsgan | hinge
+ """
+ super(AdversarialLoss, self).__init__()
+ self.type = type
+ self.register_buffer('real_label', torch.tensor(target_real_label))
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
+
+ if type == 'nsgan':
+ self.criterion = nn.BCELoss()
+ elif type == 'lsgan':
+ self.criterion = nn.MSELoss()
+ elif type == 'hinge':
+ self.criterion = nn.ReLU()
+
+ def __call__(self, outputs, is_real, is_disc=None):
+ if self.type == 'hinge':
+ if is_disc:
+ if is_real:
+ outputs = -outputs
+ return self.criterion(1 + outputs).mean()
+ else:
+ return (-outputs).mean()
+ else:
+ labels = (self.real_label
+ if is_real else self.fake_label).expand_as(outputs)
+ loss = self.criterion(outputs, labels)
+ return loss
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/lr_scheduler.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd1341cdcc64aa1c2a416b837551590ded4a43d
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/lr_scheduler.py
@@ -0,0 +1,112 @@
+"""
+ LR scheduler from BasicSR https://github.com/xinntao/BasicSR
+"""
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+ """ MultiStep with restarts learning rate scheme.
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ milestones (list): Iterations that will decrease learning rate.
+ gamma (float): Decrease ratio. Default: 0.1.
+ restarts (list): Restart iterations. Default: [0].
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+ def __init__(self,
+ optimizer,
+ milestones,
+ gamma=0.1,
+ restarts=(0, ),
+ restart_weights=(1, ),
+ last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.restarts = restarts
+ self.restart_weights = restart_weights
+ assert len(self.restarts) == len(
+ self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [
+ group['initial_lr'] * weight
+ for group in self.optimizer.param_groups
+ ]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [
+ group['lr'] * self.gamma**self.milestones[self.last_epoch]
+ for group in self.optimizer.param_groups
+ ]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
+ scheduler will restart with the weights in restart_weights.
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ periods (list): Period for each cosine anneling cycle.
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ eta_min (float): The mimimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+ def __init__(self,
+ optimizer,
+ periods,
+ restart_weights=(1, ),
+ eta_min=1e-7,
+ last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_min = eta_min
+ assert (len(self.periods) == len(self.restart_weights)
+ ), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+ ]
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch,
+ self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+
+ return [
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+ (1 + math.cos(math.pi * (
+ (self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/metrics.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..441613d8e96983b4dc72ca046a16790011b23e2a
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/metrics.py
@@ -0,0 +1,570 @@
+import numpy as np
+from skimage import measure
+from scipy import linalg
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from core.utils import to_tensors
+
+
+def calculate_epe(flow1, flow2):
+ """Calculate End point errors."""
+
+ epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt()
+ epe = epe.view(-1)
+ return epe.mean().item()
+
+
+def calculate_psnr(img1, img2):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ Returns:
+ float: psnr result.
+ """
+
+ assert img1.shape == img2.shape, \
+ (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
+
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def calc_psnr_and_ssim(img1, img2):
+ """Calculate PSNR and SSIM for images.
+ img1: ndarray, range [0, 255]
+ img2: ndarray, range [0, 255]
+ """
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ psnr = calculate_psnr(img1, img2)
+ ssim = measure.compare_ssim(img1,
+ img2,
+ data_range=255,
+ multichannel=True,
+ win_size=65)
+
+ return psnr, ssim
+
+
+###########################
+# I3D models
+###########################
+
+
+def init_i3d_model():
+ i3d_model_path = './release_model/i3d_rgb_imagenet.pt'
+ print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
+ i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
+ i3d_model.load_state_dict(torch.load(i3d_model_path))
+ i3d_model.to(torch.device('cuda:0'))
+ return i3d_model
+
+
+def calculate_i3d_activations(video1, video2, i3d_model, device):
+ """Calculate VFID metric.
+ video1: list[PIL.Image]
+ video2: list[PIL.Image]
+ """
+ video1 = to_tensors()(video1).unsqueeze(0).to(device)
+ video2 = to_tensors()(video2).unsqueeze(0).to(device)
+ video1_activations = get_i3d_activations(
+ video1, i3d_model).cpu().numpy().flatten()
+ video2_activations = get_i3d_activations(
+ video2, i3d_model).cpu().numpy().flatten()
+
+ return video1_activations, video2_activations
+
+
+def calculate_vfid(real_activations, fake_activations):
+ """
+ Given two distribution of features, compute the FID score between them
+ Params:
+ real_activations: list[ndarray]
+ fake_activations: list[ndarray]
+ """
+ m1 = np.mean(real_activations, axis=0)
+ m2 = np.mean(fake_activations, axis=0)
+ s1 = np.cov(real_activations, rowvar=False)
+ s2 = np.cov(fake_activations, rowvar=False)
+ return calculate_frechet_distance(m1, s1, m2, s2)
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+ Params:
+ -- mu1 : Numpy array containing the activations of a layer of the
+ inception net (like returned by the function 'get_predictions')
+ for generated samples.
+ -- mu2 : The sample mean over activations, precalculated on an
+ representive data set.
+ -- sigma1: The covariance matrix over activations for generated samples.
+ -- sigma2: The covariance matrix over activations, precalculated on an
+ representive data set.
+ Returns:
+ -- : The Frechet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, \
+ 'Training and test mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, \
+ 'Training and test covariances have different dimensions'
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ('fid calculation produces singular product; '
+ 'adding %s to diagonal of cov estimates') % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError('Imaginary component {}'.format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return (diff.dot(diff) + np.trace(sigma1) + # NOQA
+ np.trace(sigma2) - 2 * tr_covmean)
+
+
+def get_i3d_activations(batched_video,
+ i3d_model,
+ target_endpoint='Logits',
+ flatten=True,
+ grad_enabled=False):
+ """
+ Get features from i3d model and flatten them to 1d feature,
+ valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS
+ VALID_ENDPOINTS = (
+ 'Conv3d_1a_7x7',
+ 'MaxPool3d_2a_3x3',
+ 'Conv3d_2b_1x1',
+ 'Conv3d_2c_3x3',
+ 'MaxPool3d_3a_3x3',
+ 'Mixed_3b',
+ 'Mixed_3c',
+ 'MaxPool3d_4a_3x3',
+ 'Mixed_4b',
+ 'Mixed_4c',
+ 'Mixed_4d',
+ 'Mixed_4e',
+ 'Mixed_4f',
+ 'MaxPool3d_5a_2x2',
+ 'Mixed_5b',
+ 'Mixed_5c',
+ 'Logits',
+ 'Predictions',
+ )
+ """
+ with torch.set_grad_enabled(grad_enabled):
+ feat = i3d_model.extract_features(batched_video.transpose(1, 2),
+ target_endpoint)
+ if flatten:
+ feat = feat.view(feat.size(0), -1)
+
+ return feat
+
+
+# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
+# I only fix flake8 errors and do some cleaning here
+
+
+class MaxPool3dSamePadding(nn.MaxPool3d):
+ def compute_pad(self, dim, s):
+ if s % self.stride[dim] == 0:
+ return max(self.kernel_size[dim] - self.stride[dim], 0)
+ else:
+ return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
+
+ def forward(self, x):
+ # compute 'same' padding
+ (batch, channel, t, h, w) = x.size()
+ pad_t = self.compute_pad(0, t)
+ pad_h = self.compute_pad(1, h)
+ pad_w = self.compute_pad(2, w)
+
+ pad_t_f = pad_t // 2
+ pad_t_b = pad_t - pad_t_f
+ pad_h_f = pad_h // 2
+ pad_h_b = pad_h - pad_h_f
+ pad_w_f = pad_w // 2
+ pad_w_b = pad_w - pad_w_f
+
+ pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+ x = F.pad(x, pad)
+ return super(MaxPool3dSamePadding, self).forward(x)
+
+
+class Unit3D(nn.Module):
+ def __init__(self,
+ in_channels,
+ output_channels,
+ kernel_shape=(1, 1, 1),
+ stride=(1, 1, 1),
+ padding=0,
+ activation_fn=F.relu,
+ use_batch_norm=True,
+ use_bias=False,
+ name='unit_3d'):
+ """Initializes Unit3D module."""
+ super(Unit3D, self).__init__()
+
+ self._output_channels = output_channels
+ self._kernel_shape = kernel_shape
+ self._stride = stride
+ self._use_batch_norm = use_batch_norm
+ self._activation_fn = activation_fn
+ self._use_bias = use_bias
+ self.name = name
+ self.padding = padding
+
+ self.conv3d = nn.Conv3d(
+ in_channels=in_channels,
+ out_channels=self._output_channels,
+ kernel_size=self._kernel_shape,
+ stride=self._stride,
+ padding=0, # we always want padding to be 0 here. We will
+ # dynamically pad based on input size in forward function
+ bias=self._use_bias)
+
+ if self._use_batch_norm:
+ self.bn = nn.BatchNorm3d(self._output_channels,
+ eps=0.001,
+ momentum=0.01)
+
+ def compute_pad(self, dim, s):
+ if s % self._stride[dim] == 0:
+ return max(self._kernel_shape[dim] - self._stride[dim], 0)
+ else:
+ return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
+
+ def forward(self, x):
+ # compute 'same' padding
+ (batch, channel, t, h, w) = x.size()
+ pad_t = self.compute_pad(0, t)
+ pad_h = self.compute_pad(1, h)
+ pad_w = self.compute_pad(2, w)
+
+ pad_t_f = pad_t // 2
+ pad_t_b = pad_t - pad_t_f
+ pad_h_f = pad_h // 2
+ pad_h_b = pad_h - pad_h_f
+ pad_w_f = pad_w // 2
+ pad_w_b = pad_w - pad_w_f
+
+ pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+ x = F.pad(x, pad)
+
+ x = self.conv3d(x)
+ if self._use_batch_norm:
+ x = self.bn(x)
+ if self._activation_fn is not None:
+ x = self._activation_fn(x)
+ return x
+
+
+class InceptionModule(nn.Module):
+ def __init__(self, in_channels, out_channels, name):
+ super(InceptionModule, self).__init__()
+
+ self.b0 = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[0],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_0/Conv3d_0a_1x1')
+ self.b1a = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[1],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_1/Conv3d_0a_1x1')
+ self.b1b = Unit3D(in_channels=out_channels[1],
+ output_channels=out_channels[2],
+ kernel_shape=[3, 3, 3],
+ name=name + '/Branch_1/Conv3d_0b_3x3')
+ self.b2a = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[3],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_2/Conv3d_0a_1x1')
+ self.b2b = Unit3D(in_channels=out_channels[3],
+ output_channels=out_channels[4],
+ kernel_shape=[3, 3, 3],
+ name=name + '/Branch_2/Conv3d_0b_3x3')
+ self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
+ stride=(1, 1, 1),
+ padding=0)
+ self.b3b = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[5],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_3/Conv3d_0b_1x1')
+ self.name = name
+
+ def forward(self, x):
+ b0 = self.b0(x)
+ b1 = self.b1b(self.b1a(x))
+ b2 = self.b2b(self.b2a(x))
+ b3 = self.b3b(self.b3a(x))
+ return torch.cat([b0, b1, b2, b3], dim=1)
+
+
+class InceptionI3d(nn.Module):
+ """Inception-v1 I3D architecture.
+ The model is introduced in:
+ Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
+ Joao Carreira, Andrew Zisserman
+ https://arxiv.org/pdf/1705.07750v1.pdf.
+ See also the Inception architecture, introduced in:
+ Going deeper with convolutions
+ Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
+ Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
+ http://arxiv.org/pdf/1409.4842v1.pdf.
+ """
+
+ # Endpoints of the model in order. During construction, all the endpoints up
+ # to a designated `final_endpoint` are returned in a dictionary as the
+ # second return value.
+ VALID_ENDPOINTS = (
+ 'Conv3d_1a_7x7',
+ 'MaxPool3d_2a_3x3',
+ 'Conv3d_2b_1x1',
+ 'Conv3d_2c_3x3',
+ 'MaxPool3d_3a_3x3',
+ 'Mixed_3b',
+ 'Mixed_3c',
+ 'MaxPool3d_4a_3x3',
+ 'Mixed_4b',
+ 'Mixed_4c',
+ 'Mixed_4d',
+ 'Mixed_4e',
+ 'Mixed_4f',
+ 'MaxPool3d_5a_2x2',
+ 'Mixed_5b',
+ 'Mixed_5c',
+ 'Logits',
+ 'Predictions',
+ )
+
+ def __init__(self,
+ num_classes=400,
+ spatial_squeeze=True,
+ final_endpoint='Logits',
+ name='inception_i3d',
+ in_channels=3,
+ dropout_keep_prob=0.5):
+ """Initializes I3D model instance.
+ Args:
+ num_classes: The number of outputs in the logit layer (default 400, which
+ matches the Kinetics dataset).
+ spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
+ before returning (default True).
+ final_endpoint: The model contains many possible endpoints.
+ `final_endpoint` specifies the last endpoint for the model to be built
+ up to. In addition to the output at `final_endpoint`, all the outputs
+ at endpoints up to `final_endpoint` will also be returned, in a
+ dictionary. `final_endpoint` must be one of
+ InceptionI3d.VALID_ENDPOINTS (default 'Logits').
+ name: A string (optional). The name of this module.
+ Raises:
+ ValueError: if `final_endpoint` is not recognized.
+ """
+
+ if final_endpoint not in self.VALID_ENDPOINTS:
+ raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+ super(InceptionI3d, self).__init__()
+ self._num_classes = num_classes
+ self._spatial_squeeze = spatial_squeeze
+ self._final_endpoint = final_endpoint
+ self.logits = None
+
+ if self._final_endpoint not in self.VALID_ENDPOINTS:
+ raise ValueError('Unknown final endpoint %s' %
+ self._final_endpoint)
+
+ self.end_points = {}
+ end_point = 'Conv3d_1a_7x7'
+ self.end_points[end_point] = Unit3D(in_channels=in_channels,
+ output_channels=64,
+ kernel_shape=[7, 7, 7],
+ stride=(2, 2, 2),
+ padding=(3, 3, 3),
+ name=name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_2a_3x3'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Conv3d_2b_1x1'
+ self.end_points[end_point] = Unit3D(in_channels=64,
+ output_channels=64,
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Conv3d_2c_3x3'
+ self.end_points[end_point] = Unit3D(in_channels=64,
+ output_channels=192,
+ kernel_shape=[3, 3, 3],
+ padding=1,
+ name=name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_3a_3x3'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_3b'
+ self.end_points[end_point] = InceptionModule(192,
+ [64, 96, 128, 16, 32, 32],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_3c'
+ self.end_points[end_point] = InceptionModule(
+ 256, [128, 128, 192, 32, 96, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_4a_3x3'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4b'
+ self.end_points[end_point] = InceptionModule(
+ 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4c'
+ self.end_points[end_point] = InceptionModule(
+ 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4d'
+ self.end_points[end_point] = InceptionModule(
+ 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4e'
+ self.end_points[end_point] = InceptionModule(
+ 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4f'
+ self.end_points[end_point] = InceptionModule(
+ 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_5a_2x2'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_5b'
+ self.end_points[end_point] = InceptionModule(
+ 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_5c'
+ self.end_points[end_point] = InceptionModule(
+ 256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Logits'
+ self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
+ self.dropout = nn.Dropout(dropout_keep_prob)
+ self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
+ output_channels=self._num_classes,
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ activation_fn=None,
+ use_batch_norm=False,
+ use_bias=True,
+ name='logits')
+
+ self.build()
+
+ def replace_logits(self, num_classes):
+ self._num_classes = num_classes
+ self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
+ output_channels=self._num_classes,
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ activation_fn=None,
+ use_batch_norm=False,
+ use_bias=True,
+ name='logits')
+
+ def build(self):
+ for k in self.end_points.keys():
+ self.add_module(k, self.end_points[k])
+
+ def forward(self, x):
+ for end_point in self.VALID_ENDPOINTS:
+ if end_point in self.end_points:
+ x = self._modules[end_point](
+ x) # use _modules to work with dataparallel
+
+ x = self.logits(self.dropout(self.avg_pool(x)))
+ if self._spatial_squeeze:
+ logits = x.squeeze(3).squeeze(3)
+ # logits is batch X time X classes, which is what we want to work with
+ return logits
+
+ def extract_features(self, x, target_endpoint='Logits'):
+ for end_point in self.VALID_ENDPOINTS:
+ if end_point in self.end_points:
+ x = self._modules[end_point](x)
+ if end_point == target_endpoint:
+ break
+ if target_endpoint == 'Logits':
+ return x.mean(4).mean(3).mean(2)
+ else:
+ return x
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/trainer.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52e7fad9260f904375c295392f208f0ac624aef
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/trainer.py
@@ -0,0 +1,399 @@
+import os
+import glob
+import logging
+import importlib
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
+from core.loss import AdversarialLoss
+from core.dataset import TrainDataset
+from model.modules.flow_comp import FlowCompletionLoss
+
+
+class Trainer:
+ def __init__(self, config):
+ self.config = config
+ self.epoch = 0
+ self.iteration = 0
+ self.num_local_frames = config['train_data_loader']['num_local_frames']
+ self.num_ref_frames = config['train_data_loader']['num_ref_frames']
+ self.spynet_lr = config['trainer'].get('spynet_lr', 1.0)
+
+ # setup data set and data loader
+ self.train_dataset = TrainDataset(config['train_data_loader'])
+
+ self.train_sampler = None
+ self.train_args = config['trainer']
+ if config['distributed']:
+ self.train_sampler = DistributedSampler(
+ self.train_dataset,
+ num_replicas=config['world_size'],
+ rank=config['global_rank'])
+
+ self.train_loader = DataLoader(
+ self.train_dataset,
+ batch_size=self.train_args['batch_size'] // config['world_size'],
+ shuffle=(self.train_sampler is None),
+ num_workers=self.train_args['num_workers'],
+ sampler=self.train_sampler)
+
+ # set loss functions
+ self.adversarial_loss = AdversarialLoss(
+ type=self.config['losses']['GAN_LOSS'])
+ self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
+ self.l1_loss = nn.L1Loss()
+ self.flow_comp_loss = FlowCompletionLoss().to(self.config['device'])
+
+ # setup models including generator and discriminator
+ net = importlib.import_module('model.' + config['model']['net'])
+ self.netG = net.InpaintGenerator()
+ print(self.netG)
+ self.netG = self.netG.to(self.config['device'])
+ if not self.config['model']['no_dis']:
+ self.netD = net.Discriminator(
+ in_channels=3,
+ use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
+ self.netD = self.netD.to(self.config['device'])
+
+ # setup optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+ self.load()
+
+ if config['distributed']:
+ self.netG = DDP(self.netG,
+ device_ids=[self.config['local_rank']],
+ output_device=self.config['local_rank'],
+ broadcast_buffers=True,
+ find_unused_parameters=True)
+ if not self.config['model']['no_dis']:
+ self.netD = DDP(self.netD,
+ device_ids=[self.config['local_rank']],
+ output_device=self.config['local_rank'],
+ broadcast_buffers=True,
+ find_unused_parameters=False)
+
+ # set summary writer
+ self.dis_writer = None
+ self.gen_writer = None
+ self.summary = {}
+ if self.config['global_rank'] == 0 or (not config['distributed']):
+ self.dis_writer = SummaryWriter(
+ os.path.join(config['save_dir'], 'dis'))
+ self.gen_writer = SummaryWriter(
+ os.path.join(config['save_dir'], 'gen'))
+
+ def setup_optimizers(self):
+ """Set up optimizers."""
+ backbone_params = []
+ spynet_params = []
+ for name, param in self.netG.named_parameters():
+ if 'update_spynet' in name:
+ spynet_params.append(param)
+ else:
+ backbone_params.append(param)
+
+ optim_params = [
+ {
+ 'params': backbone_params,
+ 'lr': self.config['trainer']['lr']
+ },
+ { # finetuning learning rate for spynet
+ 'params': spynet_params,
+ 'lr': self.config['trainer']['lr'] * self.spynet_lr
+ },
+ ]
+
+ self.optimG = torch.optim.Adam(optim_params,
+ betas=(self.config['trainer']['beta1'],
+ self.config['trainer']['beta2']))
+
+ if not self.config['model']['no_dis']:
+ self.optimD = torch.optim.Adam(
+ self.netD.parameters(),
+ lr=self.config['trainer']['lr'],
+ betas=(self.config['trainer']['beta1'],
+ self.config['trainer']['beta2']))
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ scheduler_opt = self.config['trainer']['scheduler']
+ scheduler_type = scheduler_opt.pop('type')
+
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ self.scheG = MultiStepRestartLR(
+ self.optimG,
+ milestones=scheduler_opt['milestones'],
+ gamma=scheduler_opt['gamma'])
+ self.scheD = MultiStepRestartLR(
+ self.optimD,
+ milestones=scheduler_opt['milestones'],
+ gamma=scheduler_opt['gamma'])
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ self.scheG = CosineAnnealingRestartLR(
+ self.optimG,
+ periods=scheduler_opt['periods'],
+ restart_weights=scheduler_opt['restart_weights'])
+ self.scheD = CosineAnnealingRestartLR(
+ self.optimD,
+ periods=scheduler_opt['periods'],
+ restart_weights=scheduler_opt['restart_weights'])
+ else:
+ raise NotImplementedError(
+ f'Scheduler {scheduler_type} is not implemented yet.')
+
+ def update_learning_rate(self):
+ """Update learning rate."""
+ self.scheG.step()
+ self.scheD.step()
+
+ def get_lr(self):
+ """Get current learning rate."""
+ return self.optimG.param_groups[0]['lr']
+
+ def add_summary(self, writer, name, val):
+ """Add tensorboard summary."""
+ if name not in self.summary:
+ self.summary[name] = 0
+ self.summary[name] += val
+ if writer is not None and self.iteration % 100 == 0:
+ writer.add_scalar(name, self.summary[name] / 100, self.iteration)
+ self.summary[name] = 0
+
+ def load(self):
+ """Load netG (and netD)."""
+ # get the latest checkpoint
+ model_path = self.config['save_dir']
+ if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
+ latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
+ 'r').read().splitlines()[-1]
+ else:
+ ckpts = [
+ os.path.basename(i).split('.pth')[0]
+ for i in glob.glob(os.path.join(model_path, '*.pth'))
+ ]
+ ckpts.sort()
+ latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
+
+ if latest_epoch is not None:
+ gen_path = os.path.join(model_path,
+ f'gen_{int(latest_epoch):06d}.pth')
+ dis_path = os.path.join(model_path,
+ f'dis_{int(latest_epoch):06d}.pth')
+ opt_path = os.path.join(model_path,
+ f'opt_{int(latest_epoch):06d}.pth')
+
+ if self.config['global_rank'] == 0:
+ print(f'Loading model from {gen_path}...')
+ dataG = torch.load(gen_path, map_location=self.config['device'])
+ self.netG.load_state_dict(dataG)
+ if not self.config['model']['no_dis']:
+ dataD = torch.load(dis_path,
+ map_location=self.config['device'])
+ self.netD.load_state_dict(dataD)
+
+ data_opt = torch.load(opt_path, map_location=self.config['device'])
+ self.optimG.load_state_dict(data_opt['optimG'])
+ self.scheG.load_state_dict(data_opt['scheG'])
+ if not self.config['model']['no_dis']:
+ self.optimD.load_state_dict(data_opt['optimD'])
+ self.scheD.load_state_dict(data_opt['scheD'])
+ self.epoch = data_opt['epoch']
+ self.iteration = data_opt['iteration']
+
+ else:
+ if self.config['global_rank'] == 0:
+ print('Warnning: There is no trained model found.'
+ 'An initialized model will be used.')
+
+ def save(self, it):
+ """Save parameters every eval_epoch"""
+ if self.config['global_rank'] == 0:
+ # configure path
+ gen_path = os.path.join(self.config['save_dir'],
+ f'gen_{it:06d}.pth')
+ dis_path = os.path.join(self.config['save_dir'],
+ f'dis_{it:06d}.pth')
+ opt_path = os.path.join(self.config['save_dir'],
+ f'opt_{it:06d}.pth')
+ print(f'\nsaving model to {gen_path} ...')
+
+ # remove .module for saving
+ if isinstance(self.netG, torch.nn.DataParallel) \
+ or isinstance(self.netG, DDP):
+ netG = self.netG.module
+ if not self.config['model']['no_dis']:
+ netD = self.netD.module
+ else:
+ netG = self.netG
+ if not self.config['model']['no_dis']:
+ netD = self.netD
+
+ # save checkpoints
+ torch.save(netG.state_dict(), gen_path)
+ if not self.config['model']['no_dis']:
+ torch.save(netD.state_dict(), dis_path)
+ torch.save(
+ {
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optimG': self.optimG.state_dict(),
+ 'optimD': self.optimD.state_dict(),
+ 'scheG': self.scheG.state_dict(),
+ 'scheD': self.scheD.state_dict()
+ }, opt_path)
+ else:
+ torch.save(
+ {
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optimG': self.optimG.state_dict(),
+ 'scheG': self.scheG.state_dict()
+ }, opt_path)
+
+ latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
+ os.system(f"echo {it:06d} > {latest_path}")
+
+ def train(self):
+ """training entry"""
+ pbar = range(int(self.train_args['iterations']))
+ if self.config['global_rank'] == 0:
+ pbar = tqdm(pbar,
+ initial=self.iteration,
+ dynamic_ncols=True,
+ smoothing=0.01)
+
+ os.makedirs('logs', exist_ok=True)
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(filename)s[line:%(lineno)d]"
+ "%(levelname)s %(message)s",
+ datefmt="%a, %d %b %Y %H:%M:%S",
+ filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
+ filemode='w')
+
+ while True:
+ self.epoch += 1
+ if self.config['distributed']:
+ self.train_sampler.set_epoch(self.epoch)
+
+ self._train_epoch(pbar)
+ if self.iteration > self.train_args['iterations']:
+ break
+ print('\nEnd training....')
+
+ def _train_epoch(self, pbar):
+ """Process input and calculate loss every training epoch"""
+ device = self.config['device']
+
+ for frames, masks, _ in self.train_loader:
+ self.iteration += 1
+
+ frames, masks = frames.to(device), masks.to(device)
+ l_t = self.num_local_frames
+ b, t, c, h, w = frames.size()
+
+ masked_frames = (frames * (1 - masks).float())
+ gt_local_frames = (frames[:, :l_t, ...] + 1) / 2
+
+ pred_imgs, pred_flows = self.netG(masked_frames, l_t)
+ pred_imgs = pred_imgs.view(b, -1, c, h, w)
+ comp_imgs = frames * (1. - masks) + masks * pred_imgs
+
+ # compute flow completion loss
+ flow_loss = self.flow_comp_loss(pred_flows, gt_local_frames)
+
+ gen_loss = 0
+ dis_loss = 0
+
+ if not self.config['model']['no_dis']:
+ # discriminator adversarial loss
+ real_clip = self.netD(frames)
+ fake_clip = self.netD(comp_imgs.detach())
+ dis_real_loss = self.adversarial_loss(real_clip, True, True)
+ dis_fake_loss = self.adversarial_loss(fake_clip, False, True)
+ dis_loss += (dis_real_loss + dis_fake_loss) / 2
+ self.add_summary(self.dis_writer, 'loss/dis_vid_fake',
+ dis_fake_loss.item())
+ self.add_summary(self.dis_writer, 'loss/dis_vid_real',
+ dis_real_loss.item())
+ self.optimD.zero_grad()
+ dis_loss.backward()
+ self.optimD.step()
+
+ # generator adversarial loss
+ gen_clip = self.netD(comp_imgs)
+ gan_loss = self.adversarial_loss(gen_clip, True, False)
+ gan_loss = gan_loss \
+ * self.config['losses']['adversarial_weight']
+ gen_loss += gan_loss
+ self.add_summary(self.gen_writer, 'loss/gan_loss',
+ gan_loss.item())
+
+ flow_loss = flow_loss * self.config['losses']['flow_weight']
+ gen_loss += flow_loss
+ self.add_summary(self.gen_writer, 'loss/flow_loss',
+ flow_loss.item())
+
+ # generator l1 loss
+ hole_loss = self.l1_loss(pred_imgs * masks, frames * masks)
+ hole_loss = hole_loss / torch.mean(masks) \
+ * self.config['losses']['hole_weight']
+ gen_loss += hole_loss
+ self.add_summary(self.gen_writer, 'loss/hole_loss',
+ hole_loss.item())
+
+ valid_loss = self.l1_loss(pred_imgs * (1 - masks),
+ frames * (1 - masks))
+ valid_loss = valid_loss / torch.mean(1-masks) \
+ * self.config['losses']['valid_weight']
+ gen_loss += valid_loss
+ self.add_summary(self.gen_writer, 'loss/valid_loss',
+ valid_loss.item())
+
+ self.optimG.zero_grad()
+ gen_loss.backward()
+ self.optimG.step()
+
+ self.update_learning_rate()
+
+ # console logs
+ if self.config['global_rank'] == 0:
+ pbar.update(1)
+ if not self.config['model']['no_dis']:
+ pbar.set_description((f"flow: {flow_loss.item():.3f}; "
+ f"d: {dis_loss.item():.3f}; "
+ f"hole: {hole_loss.item():.3f}; "
+ f"valid: {valid_loss.item():.3f}"))
+ else:
+ pbar.set_description((f"flow: {flow_loss.item():.3f}; "
+ f"hole: {hole_loss.item():.3f}; "
+ f"valid: {valid_loss.item():.3f}"))
+
+ if self.iteration % self.train_args['log_freq'] == 0:
+ if not self.config['model']['no_dis']:
+ logging.info(f"[Iter {self.iteration}] "
+ f"flow: {flow_loss.item():.4f}; "
+ f"d: {dis_loss.item():.4f}; "
+ f"hole: {hole_loss.item():.4f}; "
+ f"valid: {valid_loss.item():.4f}")
+ else:
+ logging.info(f"[Iter {self.iteration}] "
+ f"flow: {flow_loss.item():.4f}; "
+ f"hole: {hole_loss.item():.4f}; "
+ f"valid: {valid_loss.item():.4f}")
+
+ # saving models
+ if self.iteration % self.train_args['save_freq'] == 0:
+ self.save(int(self.iteration))
+
+ if self.iteration > self.train_args['iterations']:
+ break
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/core/utils.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a173372157b69e11c28961e7760e78cedd81eec
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/core/utils.py
@@ -0,0 +1,330 @@
+import os
+import io
+import cv2
+import random
+import numpy as np
+from PIL import Image, ImageOps
+import zipfile
+
+import torch
+import matplotlib
+import matplotlib.patches as patches
+from matplotlib.path import Path
+from matplotlib import pyplot as plt
+from torchvision import transforms
+
+# matplotlib.use('agg')
+
+# ###########################################################################
+# Directory IO
+# ###########################################################################
+
+
+def read_dirnames_under_root(root_dir):
+ dirnames = [
+ name for i, name in enumerate(sorted(os.listdir(root_dir)))
+ if os.path.isdir(os.path.join(root_dir, name))
+ ]
+ print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
+ return dirnames
+
+
+class TrainZipReader(object):
+ file_dict = dict()
+
+ def __init__(self):
+ super(TrainZipReader, self).__init__()
+
+ @staticmethod
+ def build_file_dict(path):
+ file_dict = TrainZipReader.file_dict
+ if path in file_dict:
+ return file_dict[path]
+ else:
+ file_handle = zipfile.ZipFile(path, 'r')
+ file_dict[path] = file_handle
+ return file_dict[path]
+
+ @staticmethod
+ def imread(path, idx):
+ zfile = TrainZipReader.build_file_dict(path)
+ filelist = zfile.namelist()
+ filelist.sort()
+ data = zfile.read(filelist[idx])
+ #
+ im = Image.open(io.BytesIO(data))
+ return im
+
+
+class TestZipReader(object):
+ file_dict = dict()
+
+ def __init__(self):
+ super(TestZipReader, self).__init__()
+
+ @staticmethod
+ def build_file_dict(path):
+ file_dict = TestZipReader.file_dict
+ if path in file_dict:
+ return file_dict[path]
+ else:
+ file_handle = zipfile.ZipFile(path, 'r')
+ file_dict[path] = file_handle
+ return file_dict[path]
+
+ @staticmethod
+ def imread(path, idx):
+ zfile = TestZipReader.build_file_dict(path)
+ filelist = zfile.namelist()
+ filelist.sort()
+ data = zfile.read(filelist[idx])
+ file_bytes = np.asarray(bytearray(data), dtype=np.uint8)
+ im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+ im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+ # im = Image.open(io.BytesIO(data))
+ return im
+
+
+# ###########################################################################
+# Data augmentation
+# ###########################################################################
+
+
+def to_tensors():
+ return transforms.Compose([Stack(), ToTorchFormatTensor()])
+
+
+class GroupRandomHorizontalFlowFlip(object):
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
+ """
+ def __init__(self, is_flow=True):
+ self.is_flow = is_flow
+
+ def __call__(self, img_group, mask_group, flowF_group, flowB_group):
+ v = random.random()
+ if v < 0.5:
+ ret_img = [
+ img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group
+ ]
+ ret_mask = [
+ mask.transpose(Image.FLIP_LEFT_RIGHT) for mask in mask_group
+ ]
+ ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group]
+ ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group]
+ return ret_img, ret_mask, ret_flowF, ret_flowB
+ else:
+ return img_group, mask_group, flowF_group, flowB_group
+
+
+class GroupRandomHorizontalFlip(object):
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
+ """
+ def __init__(self, is_flow=False):
+ self.is_flow = is_flow
+
+ def __call__(self, img_group, is_flow=False):
+ v = random.random()
+ if v < 0.5:
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
+ if self.is_flow:
+ for i in range(0, len(ret), 2):
+ # invert flow pixel values when flipping
+ ret[i] = ImageOps.invert(ret[i])
+ return ret
+ else:
+ return img_group
+
+
+class Stack(object):
+ def __init__(self, roll=False):
+ self.roll = roll
+
+ def __call__(self, img_group):
+ mode = img_group[0].mode
+ if mode == '1':
+ img_group = [img.convert('L') for img in img_group]
+ mode = 'L'
+ if mode == 'L':
+ return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
+ elif mode == 'RGB':
+ if self.roll:
+ return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
+ axis=2)
+ else:
+ return np.stack(img_group, axis=2)
+ else:
+ raise NotImplementedError(f"Image mode {mode}")
+
+
+class ToTorchFormatTensor(object):
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
+ def __init__(self, div=True):
+ self.div = div
+
+ def __call__(self, pic):
+ if isinstance(pic, np.ndarray):
+ # numpy img: [L, C, H, W]
+ img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
+ else:
+ # handle PIL Image
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(
+ pic.tobytes()))
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+ # put it from HWC to CHW format
+ # yikes, this transpose takes 80% of the loading time/CPU
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
+ img = img.float().div(255) if self.div else img.float()
+ return img
+
+
+# ###########################################################################
+# Create masks with random shape
+# ###########################################################################
+
+
+def create_random_shape_with_random_motion(video_length,
+ imageHeight=240,
+ imageWidth=432):
+ # get a random shape
+ height = random.randint(imageHeight // 3, imageHeight - 1)
+ width = random.randint(imageWidth // 3, imageWidth - 1)
+ edge_num = random.randint(6, 8)
+ ratio = random.randint(6, 8) / 10
+ region = get_random_shape(edge_num=edge_num,
+ ratio=ratio,
+ height=height,
+ width=width)
+ region_width, region_height = region.size
+ # get random position
+ x, y = random.randint(0, imageHeight - region_height), random.randint(
+ 0, imageWidth - region_width)
+ velocity = get_random_velocity(max_speed=3)
+ m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+ m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+ masks = [m.convert('L')]
+ # return fixed masks
+ if random.uniform(0, 1) > 0.5:
+ return masks * video_length
+ # return moving masks
+ for _ in range(video_length - 1):
+ x, y, velocity = random_move_control_points(x,
+ y,
+ imageHeight,
+ imageWidth,
+ velocity,
+ region.size,
+ maxLineAcceleration=(3,
+ 0.5),
+ maxInitSpeed=3)
+ m = Image.fromarray(
+ np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+ m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+ masks.append(m.convert('L'))
+ return masks
+
+
+def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
+ '''
+ There is the initial point and 3 points per cubic bezier curve.
+ Thus, the curve will only pass though n points, which will be the sharp edges.
+ The other 2 modify the shape of the bezier curve.
+ edge_num, Number of possibly sharp edges
+ points_num, number of points in the Path
+ ratio, (0, 1) magnitude of the perturbation from the unit circle,
+ '''
+ points_num = edge_num * 3 + 1
+ angles = np.linspace(0, 2 * np.pi, points_num)
+ codes = np.full(points_num, Path.CURVE4)
+ codes[0] = Path.MOVETO
+ # Using this instad of Path.CLOSEPOLY avoids an innecessary straight line
+ verts = np.stack((np.cos(angles), np.sin(angles))).T * \
+ (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
+ verts[-1, :] = verts[0, :]
+ path = Path(verts, codes)
+ # draw paths into images
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ patch = patches.PathPatch(path, facecolor='black', lw=2)
+ ax.add_patch(patch)
+ ax.set_xlim(np.min(verts) * 1.1, np.max(verts) * 1.1)
+ ax.set_ylim(np.min(verts) * 1.1, np.max(verts) * 1.1)
+ ax.axis('off') # removes the axis to leave only the shape
+ fig.canvas.draw()
+ # convert plt images into numpy images
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape((fig.canvas.get_width_height()[::-1] + (3, )))
+ plt.close(fig)
+ # postprocess
+ data = cv2.resize(data, (width, height))[:, :, 0]
+ data = (1 - np.array(data > 0).astype(np.uint8)) * 255
+ corrdinates = np.where(data > 0)
+ xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
+ corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
+ region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
+ return region
+
+
+def random_accelerate(velocity, maxAcceleration, dist='uniform'):
+ speed, angle = velocity
+ d_speed, d_angle = maxAcceleration
+ if dist == 'uniform':
+ speed += np.random.uniform(-d_speed, d_speed)
+ angle += np.random.uniform(-d_angle, d_angle)
+ elif dist == 'guassian':
+ speed += np.random.normal(0, d_speed / 2)
+ angle += np.random.normal(0, d_angle / 2)
+ else:
+ raise NotImplementedError(
+ f'Distribution type {dist} is not supported.')
+ return (speed, angle)
+
+
+def get_random_velocity(max_speed=3, dist='uniform'):
+ if dist == 'uniform':
+ speed = np.random.uniform(max_speed)
+ elif dist == 'guassian':
+ speed = np.abs(np.random.normal(0, max_speed / 2))
+ else:
+ raise NotImplementedError(
+ f'Distribution type {dist} is not supported.')
+ angle = np.random.uniform(0, 2 * np.pi)
+ return (speed, angle)
+
+
+def random_move_control_points(X,
+ Y,
+ imageHeight,
+ imageWidth,
+ lineVelocity,
+ region_size,
+ maxLineAcceleration=(3, 0.5),
+ maxInitSpeed=3):
+ region_width, region_height = region_size
+ speed, angle = lineVelocity
+ X += int(speed * np.cos(angle))
+ Y += int(speed * np.sin(angle))
+ lineVelocity = random_accelerate(lineVelocity,
+ maxLineAcceleration,
+ dist='guassian')
+ if ((X > imageHeight - region_height) or (X < 0)
+ or (Y > imageWidth - region_width) or (Y < 0)):
+ lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
+ new_X = np.clip(X, 0, imageHeight - region_height)
+ new_Y = np.clip(Y, 0, imageWidth - region_width)
+ return new_X, new_Y, lineVelocity
+
+
+if __name__ == '__main__':
+
+ trials = 10
+ for _ in range(trials):
+ video_length = 10
+ # The returned masks are either stationary (50%) or moving (50%)
+ masks = create_random_shape_with_random_motion(video_length,
+ imageHeight=240,
+ imageWidth=432)
+
+ for m in masks:
+ cv2.imshow('mask', np.array(m))
+ cv2.waitKey(500)
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/evaluate.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8f70789ce9f510767bf6cae12d4f374749ad8ec
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/evaluate.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+import cv2
+import numpy as np
+import importlib
+import os
+import argparse
+from PIL import Image
+
+import torch
+from torch.utils.data import DataLoader
+
+from core.dataset import TestDataset
+from core.metrics import calc_psnr_and_ssim, calculate_i3d_activations, calculate_vfid, init_i3d_model
+
+# global variables
+w, h = 432, 240
+ref_length = 10
+neighbor_stride = 5
+default_fps = 24
+
+
+# sample reference frames from the whole video
+def get_ref_index(neighbor_ids, length):
+ ref_index = []
+ for i in range(0, length, ref_length):
+ if i not in neighbor_ids:
+ ref_index.append(i)
+ return ref_index
+
+
+def main_worker(args):
+ args.size = (w, h)
+ # set up datasets and data loader
+ assert (args.dataset == 'davis') or args.dataset == 'youtube-vos', \
+ f"{args.dataset} dataset is not supported"
+ test_dataset = TestDataset(args)
+
+ test_loader = DataLoader(test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=args.num_workers)
+
+ # set up models
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ net = importlib.import_module('model.' + args.model)
+ model = net.InpaintGenerator().to(device)
+ data = torch.load(args.ckpt, map_location=device)
+ model.load_state_dict(data)
+ print(f'Loading from: {args.ckpt}')
+ model.eval()
+
+ total_frame_psnr = []
+ total_frame_ssim = []
+
+ output_i3d_activations = []
+ real_i3d_activations = []
+
+ print('Start evaluation...')
+
+ # create results directory
+ result_path = os.path.join('results', f'{args.model}_{args.dataset}')
+ if not os.path.exists(result_path):
+ os.makedirs(result_path)
+ eval_summary = open(
+ os.path.join(result_path, f"{args.model}_{args.dataset}_metrics.txt"),
+ "w")
+
+ i3d_model = init_i3d_model()
+
+ for index, items in enumerate(test_loader):
+ frames, masks, video_name, frames_PIL = items
+
+ video_length = frames.size(1)
+ frames, masks = frames.to(device), masks.to(device)
+ ori_frames = frames_PIL
+ ori_frames = [
+ ori_frames[i].squeeze().cpu().numpy() for i in range(video_length)
+ ]
+ comp_frames = [None] * video_length
+
+ # complete holes by our model
+ for f in range(0, video_length, neighbor_stride):
+ neighbor_ids = [
+ i for i in range(max(0, f - neighbor_stride),
+ min(video_length, f + neighbor_stride + 1))
+ ]
+ ref_ids = get_ref_index(neighbor_ids, video_length)
+ selected_imgs = frames[:1, neighbor_ids + ref_ids, :, :, :]
+ selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
+ with torch.no_grad():
+ masked_frames = selected_imgs * (1 - selected_masks)
+ pred_img, _ = model(masked_frames, len(neighbor_ids))
+
+ pred_img = (pred_img + 1) / 2
+ pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
+ binary_masks = masks[0, neighbor_ids, :, :, :].cpu().permute(
+ 0, 2, 3, 1).numpy().astype(np.uint8)
+ for i in range(len(neighbor_ids)):
+ idx = neighbor_ids[i]
+ img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+ + ori_frames[idx] * (1 - binary_masks[i])
+ if comp_frames[idx] is None:
+ comp_frames[idx] = img
+ else:
+ comp_frames[idx] = comp_frames[idx].astype(
+ np.float32) * 0.5 + img.astype(np.float32) * 0.5
+
+ # calculate metrics
+ cur_video_psnr = []
+ cur_video_ssim = []
+ comp_PIL = [] # to calculate VFID
+ frames_PIL = []
+ for ori, comp in zip(ori_frames, comp_frames):
+ psnr, ssim = calc_psnr_and_ssim(ori, comp)
+
+ cur_video_psnr.append(psnr)
+ cur_video_ssim.append(ssim)
+
+ total_frame_psnr.append(psnr)
+ total_frame_ssim.append(ssim)
+
+ frames_PIL.append(Image.fromarray(ori.astype(np.uint8)))
+ comp_PIL.append(Image.fromarray(comp.astype(np.uint8)))
+ cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr)
+ cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim)
+
+ # saving i3d activations
+ frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL,
+ comp_PIL,
+ i3d_model,
+ device=device)
+ real_i3d_activations.append(frames_i3d)
+ output_i3d_activations.append(comp_i3d)
+
+ print(
+ f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}'
+ )
+ eval_summary.write(
+ f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}\n'
+ )
+
+ # saving images for evaluating warpping errors
+ if args.save_results:
+ save_frame_path = os.path.join(result_path, video_name[0])
+ os.makedirs(save_frame_path, exist_ok=False)
+
+ for i, frame in enumerate(comp_frames):
+ cv2.imwrite(
+ os.path.join(save_frame_path,
+ str(i).zfill(5) + '.png'),
+ cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+ avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
+ avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim)
+
+ fid_score = calculate_vfid(real_i3d_activations, output_i3d_activations)
+ print('Finish evaluation... Average Frame PSNR/SSIM/VFID: '
+ f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
+ eval_summary.write(
+ 'Finish evaluation... Average Frame PSNR/SSIM/VFID: '
+ f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
+ eval_summary.close()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='E2FGVI')
+ parser.add_argument('--dataset',
+ choices=['davis', 'youtube-vos'],
+ type=str)
+ parser.add_argument('--data_root', type=str, required=True)
+ parser.add_argument('--model', choices=['e2fgvi', 'e2fgvi_hq'], type=str)
+ parser.add_argument('--ckpt', type=str, required=True)
+ parser.add_argument('--save_results', action='store_true', default=False)
+ parser.add_argument('--num_workers', default=4, type=int)
+ args = parser.parse_args()
+ main_worker(args)
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/__init__.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/e2fgvi.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/e2fgvi.py
new file mode 100644
index 0000000000000000000000000000000000000000..cac63a3786e71b1e692e28996128b5869b9be3fd
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/e2fgvi.py
@@ -0,0 +1,350 @@
+''' Towards An End-to-End Framework for Video Inpainting
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from E2FGVI.model.modules.flow_comp import SPyNet
+from E2FGVI.model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
+from E2FGVI.model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
+from E2FGVI.model.modules.spectral_norm import spectral_norm as _spectral_norm
+
+
+class BaseNetwork(nn.Module):
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def print_network(self):
+ if isinstance(self, list):
+ self = self[0]
+ num_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ print(
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
+ 'To see the architecture, do print(network).' %
+ (type(self).__name__, num_params / 1000000))
+
+ def init_weights(self, init_type='normal', gain=0.02):
+ '''
+ initialize network's weights
+ init_type: normal | xavier | kaiming | orthogonal
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
+ '''
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('InstanceNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ nn.init.constant_(m.weight.data, 1.0)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
+ or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ nn.init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ nn.init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(
+ 'initialization method [%s] is not implemented' %
+ init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+
+class Encoder(nn.Module):
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.group = [1, 2, 4, 8, 1]
+ self.layers = nn.ModuleList([
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ])
+
+ def forward(self, x):
+ bt, c, h, w = x.size()
+ h, w = h // 4, w // 4
+ out = x
+ for i, layer in enumerate(self.layers):
+ if i == 8:
+ x0 = out
+ if i > 8 and i % 2 == 0:
+ g = self.group[(i - 8) // 2]
+ x = x0.view(bt, g, -1, h, w)
+ o = out.view(bt, g, -1, h, w)
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
+ out = layer(out)
+ return out
+
+
+class deconv(nn.Module):
+ def __init__(self,
+ input_channel,
+ output_channel,
+ kernel_size=3,
+ padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(input_channel,
+ output_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ def forward(self, x):
+ x = F.interpolate(x,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True)
+ return self.conv(x)
+
+
+class InpaintGenerator(BaseNetwork):
+ def __init__(self, init_weights=True):
+ super(InpaintGenerator, self).__init__()
+ channel = 256
+ hidden = 512
+
+ # encoder
+ self.encoder = Encoder()
+
+ # decoder
+ self.decoder = nn.Sequential(
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(64, 64, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
+
+ # feature propagation module
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
+
+ # soft split and soft composition
+ kernel_size = (7, 7)
+ padding = (3, 3)
+ stride = (3, 3)
+ output_size = (60, 108)
+ t2t_params = {
+ 'kernel_size': kernel_size,
+ 'stride': stride,
+ 'padding': padding,
+ 'output_size': output_size
+ }
+ self.ss = SoftSplit(channel // 2,
+ hidden,
+ kernel_size,
+ stride,
+ padding,
+ t2t_param=t2t_params)
+ self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
+ stride, padding)
+
+ n_vecs = 1
+ for i, d in enumerate(kernel_size):
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
+ (d - 1) - 1) / stride[i] + 1)
+
+ blocks = []
+ depths = 8
+ num_heads = [4] * depths
+ window_size = [(5, 9)] * depths
+ focal_windows = [(5, 9)] * depths
+ focal_levels = [2] * depths
+ pool_method = "fc"
+
+ for i in range(depths):
+ blocks.append(
+ TemporalFocalTransformerBlock(dim=hidden,
+ num_heads=num_heads[i],
+ window_size=window_size[i],
+ focal_level=focal_levels[i],
+ focal_window=focal_windows[i],
+ n_vecs=n_vecs,
+ t2t_params=t2t_params,
+ pool_method=pool_method))
+ self.transformer = nn.Sequential(*blocks)
+
+ if init_weights:
+ self.init_weights()
+ # Need to initial the weights of MSDeformAttn specifically
+ for m in self.modules():
+ if isinstance(m, SecondOrderDeformableAlignment):
+ m.init_offset()
+
+ # flow completion network
+ self.update_spynet = SPyNet()
+
+ def forward_bidirect_flow(self, masked_local_frames):
+ b, l_t, c, h, w = masked_local_frames.size()
+
+ # compute forward and backward flows of masked frames
+ masked_local_frames = F.interpolate(masked_local_frames.view(
+ -1, c, h, w),
+ scale_factor=1 / 4,
+ mode='bilinear',
+ align_corners=True,
+ recompute_scale_factor=True)
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
+ w // 4)
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
+
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+
+ return pred_flows_forward, pred_flows_backward
+
+ def forward(self, masked_frames, num_local_frames):
+ l_t = num_local_frames
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
+
+ # normalization before feeding into the flow completion module
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
+
+ # extracting features and performing the feature propagation on local features
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
+ _, c, h, w = enc_feat.size()
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
+ pred_flows[1])
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
+
+ # content hallucination through stacking multiple temporal focal transformer blocks
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
+ trans_feat = self.transformer(trans_feat)
+ trans_feat = self.sc(trans_feat, t)
+ trans_feat = trans_feat.view(b, t, -1, h, w)
+ enc_feat = enc_feat + trans_feat
+
+ # decode frames from features
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
+ output = torch.tanh(output)
+ return output, pred_flows
+
+
+# ######################################################################
+# Discriminator for Temporal Patch GAN
+# ######################################################################
+
+
+class Discriminator(BaseNetwork):
+ def __init__(self,
+ in_channels=3,
+ use_sigmoid=False,
+ use_spectral_norm=True,
+ init_weights=True):
+ super(Discriminator, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ nf = 32
+
+ self.conv = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=1,
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2)))
+
+ if init_weights:
+ self.init_weights()
+
+ def forward(self, xs):
+ # T, C, H, W = xs.shape (old)
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.conv(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+
+def spectral_norm(module, mode=True):
+ if mode:
+ return _spectral_norm(module)
+ return module
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/e2fgvi_hq.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/e2fgvi_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6bc78760ebc22ce52a80ee218e07985098abf7d
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/e2fgvi_hq.py
@@ -0,0 +1,350 @@
+''' Towards An End-to-End Framework for Video Inpainting
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from E2FGVI.model.modules.flow_comp import SPyNet
+from E2FGVI.model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
+from E2FGVI.model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
+from E2FGVI.model.modules.spectral_norm import spectral_norm as _spectral_norm
+
+
+class BaseNetwork(nn.Module):
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def print_network(self):
+ if isinstance(self, list):
+ self = self[0]
+ num_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ print(
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
+ 'To see the architecture, do print(network).' %
+ (type(self).__name__, num_params / 1000000))
+
+ def init_weights(self, init_type='normal', gain=0.02):
+ '''
+ initialize network's weights
+ init_type: normal | xavier | kaiming | orthogonal
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
+ '''
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('InstanceNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ nn.init.constant_(m.weight.data, 1.0)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
+ or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ nn.init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ nn.init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(
+ 'initialization method [%s] is not implemented' %
+ init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+
+class Encoder(nn.Module):
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.group = [1, 2, 4, 8, 1]
+ self.layers = nn.ModuleList([
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ])
+
+ def forward(self, x):
+ bt, c, _, _ = x.size()
+ # h, w = h//4, w//4
+ out = x
+ for i, layer in enumerate(self.layers):
+ if i == 8:
+ x0 = out
+ _, _, h, w = x0.size()
+ if i > 8 and i % 2 == 0:
+ g = self.group[(i - 8) // 2]
+ x = x0.view(bt, g, -1, h, w)
+ o = out.view(bt, g, -1, h, w)
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
+ out = layer(out)
+ return out
+
+
+class deconv(nn.Module):
+ def __init__(self,
+ input_channel,
+ output_channel,
+ kernel_size=3,
+ padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(input_channel,
+ output_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ def forward(self, x):
+ x = F.interpolate(x,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True)
+ return self.conv(x)
+
+
+class InpaintGenerator(BaseNetwork):
+ def __init__(self, init_weights=True):
+ super(InpaintGenerator, self).__init__()
+ channel = 256
+ hidden = 512
+
+ # encoder
+ self.encoder = Encoder()
+
+ # decoder
+ self.decoder = nn.Sequential(
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(64, 64, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
+
+ # feature propagation module
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
+
+ # soft split and soft composition
+ kernel_size = (7, 7)
+ padding = (3, 3)
+ stride = (3, 3)
+ output_size = (60, 108)
+ t2t_params = {
+ 'kernel_size': kernel_size,
+ 'stride': stride,
+ 'padding': padding
+ }
+ self.ss = SoftSplit(channel // 2,
+ hidden,
+ kernel_size,
+ stride,
+ padding,
+ t2t_param=t2t_params)
+ self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
+
+ n_vecs = 1
+ for i, d in enumerate(kernel_size):
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
+ (d - 1) - 1) / stride[i] + 1)
+
+ blocks = []
+ depths = 8
+ num_heads = [4] * depths
+ window_size = [(5, 9)] * depths
+ focal_windows = [(5, 9)] * depths
+ focal_levels = [2] * depths
+ pool_method = "fc"
+
+ for i in range(depths):
+ blocks.append(
+ TemporalFocalTransformerBlock(dim=hidden,
+ num_heads=num_heads[i],
+ window_size=window_size[i],
+ focal_level=focal_levels[i],
+ focal_window=focal_windows[i],
+ n_vecs=n_vecs,
+ t2t_params=t2t_params,
+ pool_method=pool_method))
+ self.transformer = nn.Sequential(*blocks)
+
+ if init_weights:
+ self.init_weights()
+ # Need to initial the weights of MSDeformAttn specifically
+ for m in self.modules():
+ if isinstance(m, SecondOrderDeformableAlignment):
+ m.init_offset()
+
+ # flow completion network
+ self.update_spynet = SPyNet()
+
+ def forward_bidirect_flow(self, masked_local_frames):
+ b, l_t, c, h, w = masked_local_frames.size()
+
+ # compute forward and backward flows of masked frames
+ masked_local_frames = F.interpolate(masked_local_frames.view(
+ -1, c, h, w),
+ scale_factor=1 / 4,
+ mode='bilinear',
+ align_corners=True,
+ recompute_scale_factor=True)
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
+ w // 4)
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
+
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
+ w // 4)
+
+ return pred_flows_forward, pred_flows_backward
+
+ def forward(self, masked_frames, num_local_frames):
+ l_t = num_local_frames
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
+
+ # normalization before feeding into the flow completion module
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
+
+ # extracting features and performing the feature propagation on local features
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
+ _, c, h, w = enc_feat.size()
+ fold_output_size = (h, w)
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
+ pred_flows[1])
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
+
+ # content hallucination through stacking multiple temporal focal transformer blocks
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
+ trans_feat = self.transformer([trans_feat, fold_output_size])
+ trans_feat = self.sc(trans_feat[0], t, fold_output_size)
+ trans_feat = trans_feat.view(b, t, -1, h, w)
+ enc_feat = enc_feat + trans_feat
+
+ # decode frames from features
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
+ output = torch.tanh(output)
+ return output, pred_flows
+
+
+# ######################################################################
+# Discriminator for Temporal Patch GAN
+# ######################################################################
+
+
+class Discriminator(BaseNetwork):
+ def __init__(self,
+ in_channels=3,
+ use_sigmoid=False,
+ use_spectral_norm=True,
+ init_weights=True):
+ super(Discriminator, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ nf = 32
+
+ self.conv = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=1,
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2)))
+
+ if init_weights:
+ self.init_weights()
+
+ def forward(self, xs):
+ # T, C, H, W = xs.shape (old)
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.conv(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+
+def spectral_norm(module, mode=True):
+ if mode:
+ return _spectral_norm(module)
+ return module
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/__init__.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/feat_prop.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/feat_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..3957a72e8e97c4f88c45da4fc12334c343073ce2
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/feat_prop.py
@@ -0,0 +1,149 @@
+"""
+ BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
+"""
+import torch
+import torch.nn as nn
+
+from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
+from mmcv.cnn import constant_init
+
+from E2FGVI.model.modules.flow_comp import flow_warp
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
+ """Second-order deformable alignment module."""
+ def __init__(self, *args, **kwargs):
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
+
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
+ )
+
+ self.init_offset()
+
+ def init_offset(self):
+ constant_init(self.conv_offset[-1], val=0, bias=0)
+
+ def forward(self, x, extra_feat, flow_1, flow_2):
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
+ out = self.conv_offset(extra_feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(
+ torch.cat((o1, o2), dim=1))
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1,
+ offset_1.size(1) // 2, 1,
+ 1)
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1,
+ offset_2.size(1) // 2, 1,
+ 1)
+ offset = torch.cat([offset_1, offset_2], dim=1)
+
+ # mask
+ mask = torch.sigmoid(mask)
+
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
+
+
+class BidirectionalPropagation(nn.Module):
+ def __init__(self, channel):
+ super(BidirectionalPropagation, self).__init__()
+ modules = ['backward_', 'forward_']
+ self.deform_align = nn.ModuleDict()
+ self.backbone = nn.ModuleDict()
+ self.channel = channel
+
+ for i, module in enumerate(modules):
+ self.deform_align[module] = SecondOrderDeformableAlignment(
+ 2 * channel, channel, 3, padding=1, deform_groups=16)
+
+ self.backbone[module] = nn.Sequential(
+ nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(channel, channel, 3, 1, 1),
+ )
+
+ self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
+
+ def forward(self, x, flows_backward, flows_forward):
+ """
+ x shape : [b, t, c, h, w]
+ return [b, t, c, h, w]
+ """
+ b, t, c, h, w = x.shape
+ feats = {}
+ feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
+
+ for module_name in ['backward_', 'forward_']:
+
+ feats[module_name] = []
+
+ frame_idx = range(0, t)
+ flow_idx = range(-1, t - 1)
+ mapping_idx = list(range(0, len(feats['spatial'])))
+ mapping_idx += mapping_idx[::-1]
+
+ if 'backward' in module_name:
+ frame_idx = frame_idx[::-1]
+ flows = flows_backward
+ else:
+ flows = flows_forward
+
+ feat_prop = x.new_zeros(b, self.channel, h, w)
+ for i, idx in enumerate(frame_idx):
+ feat_current = feats['spatial'][mapping_idx[idx]]
+
+ if i > 0:
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
+
+ # initialize second-order features
+ feat_n2 = torch.zeros_like(feat_prop)
+ flow_n2 = torch.zeros_like(flow_n1)
+ cond_n2 = torch.zeros_like(cond_n1)
+ if i > 1:
+ feat_n2 = feats[module_name][-2]
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
+ flow_n2 = flow_n1 + flow_warp(
+ flow_n2, flow_n1.permute(0, 2, 3, 1))
+ cond_n2 = flow_warp(feat_n2,
+ flow_n2.permute(0, 2, 3, 1))
+
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
+ feat_prop = self.deform_align[module_name](feat_prop, cond,
+ flow_n1,
+ flow_n2)
+
+ feat = [feat_current] + [
+ feats[k][idx]
+ for k in feats if k not in ['spatial', module_name]
+ ] + [feat_prop]
+
+ feat = torch.cat(feat, dim=1)
+ feat_prop = feat_prop + self.backbone[module_name](feat)
+ feats[module_name].append(feat_prop)
+
+ if 'backward' in module_name:
+ feats[module_name] = feats[module_name][::-1]
+
+ outputs = []
+ for i in range(0, t):
+ align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
+ align_feats = torch.cat(align_feats, dim=1)
+ outputs.append(self.fusion(align_feats))
+
+ return torch.stack(outputs, dim=1) + x
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/flow_comp.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/flow_comp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33a8069e52803b9824798ee2b6602dfe560f83b
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/flow_comp.py
@@ -0,0 +1,450 @@
+import numpy as np
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+from mmcv.cnn import ConvModule
+from mmcv.runner import load_checkpoint
+
+
+class FlowCompletionLoss(nn.Module):
+ """Flow completion loss"""
+ def __init__(self):
+ super().__init__()
+ self.fix_spynet = SPyNet()
+ for p in self.fix_spynet.parameters():
+ p.requires_grad = False
+
+ self.l1_criterion = nn.L1Loss()
+
+ def forward(self, pred_flows, gt_local_frames):
+ b, l_t, c, h, w = gt_local_frames.size()
+
+ with torch.no_grad():
+ # compute gt forward and backward flows
+ gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
+ scale_factor=1 / 4,
+ mode='bilinear',
+ align_corners=True,
+ recompute_scale_factor=True)
+ gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
+ gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
+ -1, c, h // 4, w // 4)
+ gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
+ gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
+
+ # calculate loss for flow completion
+ forward_flow_loss = self.l1_criterion(
+ pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
+ backward_flow_loss = self.l1_criterion(
+ pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
+ flow_loss = forward_flow_loss + backward_flow_loss
+
+ return flow_loss
+
+
+class SPyNet(nn.Module):
+ """SPyNet network structure.
+ The difference to the SPyNet in [tof.py] is that
+ 1. more SPyNetBasicModule is used in this version, and
+ 2. no batch normalization is used in this version.
+ Paper:
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
+ Args:
+ pretrained (str): path for pre-trained SPyNet. Default: None.
+ """
+ def __init__(
+ self,
+ use_pretrain=True,
+ pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
+ ):
+ super().__init__()
+
+ self.basic_module = nn.ModuleList(
+ [SPyNetBasicModule() for _ in range(6)])
+
+ if use_pretrain:
+ if isinstance(pretrained, str):
+ print("load pretrained SPyNet...")
+ load_checkpoint(self, pretrained, strict=True)
+ elif pretrained is not None:
+ raise TypeError('[pretrained] should be str or None, '
+ f'but got {type(pretrained)}.')
+
+ self.register_buffer(
+ 'mean',
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer(
+ 'std',
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def compute_flow(self, ref, supp):
+ """Compute flow from ref to supp.
+ Note that in this function, the images are already resized to a
+ multiple of 32.
+ Args:
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
+ Returns:
+ Tensor: Estimated optical flow: (n, 2, h, w).
+ """
+ n, _, h, w = ref.size()
+
+ # normalize the input images
+ ref = [(ref - self.mean) / self.std]
+ supp = [(supp - self.mean) / self.std]
+
+ # generate downsampled frames
+ for level in range(5):
+ ref.append(
+ F.avg_pool2d(input=ref[-1],
+ kernel_size=2,
+ stride=2,
+ count_include_pad=False))
+ supp.append(
+ F.avg_pool2d(input=supp[-1],
+ kernel_size=2,
+ stride=2,
+ count_include_pad=False))
+ ref = ref[::-1]
+ supp = supp[::-1]
+
+ # flow computation
+ flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
+ for level in range(len(ref)):
+ if level == 0:
+ flow_up = flow
+ else:
+ flow_up = F.interpolate(input=flow,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True) * 2.0
+
+ # add the residue to the upsampled flow
+ flow = flow_up + self.basic_module[level](torch.cat([
+ ref[level],
+ flow_warp(supp[level],
+ flow_up.permute(0, 2, 3, 1).contiguous(),
+ padding_mode='border'), flow_up
+ ], 1))
+
+ return flow
+
+ def forward(self, ref, supp):
+ """Forward function of SPyNet.
+ This function computes the optical flow from ref to supp.
+ Args:
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
+ Returns:
+ Tensor: Estimated optical flow: (n, 2, h, w).
+ """
+
+ # upsize to a multiple of 32
+ h, w = ref.shape[2:4]
+ w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
+ h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
+ ref = F.interpolate(input=ref,
+ size=(h_up, w_up),
+ mode='bilinear',
+ align_corners=False)
+ supp = F.interpolate(input=supp,
+ size=(h_up, w_up),
+ mode='bilinear',
+ align_corners=False)
+
+ # compute flow, and resize back to the original resolution
+ flow = F.interpolate(input=self.compute_flow(ref, supp),
+ size=(h, w),
+ mode='bilinear',
+ align_corners=False)
+
+ # adjust the flow values
+ flow[:, 0, :, :] *= float(w) / float(w_up)
+ flow[:, 1, :, :] *= float(h) / float(h_up)
+
+ return flow
+
+
+class SPyNetBasicModule(nn.Module):
+ """Basic Module for SPyNet.
+ Paper:
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
+ """
+ def __init__(self):
+ super().__init__()
+
+ self.basic_module = nn.Sequential(
+ ConvModule(in_channels=8,
+ out_channels=32,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=32,
+ out_channels=64,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=64,
+ out_channels=32,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=32,
+ out_channels=16,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')),
+ ConvModule(in_channels=16,
+ out_channels=2,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ norm_cfg=None,
+ act_cfg=None))
+
+ def forward(self, tensor_input):
+ """
+ Args:
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
+ 8 channels contain:
+ [reference image (3), neighbor image (3), initial flow (2)].
+ Returns:
+ Tensor: Refined flow with shape (b, 2, h, w)
+ """
+ return self.basic_module(tensor_input)
+
+
+# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
+def make_colorwheel():
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+ Code follows the original C++ source code of Daniel Scharstein.
+ Code follows the the Matlab source code of Deqing Sun.
+
+ Returns:
+ np.ndarray: Color wheel
+ """
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
+
+
+def flow_uv_to_colors(u, v, convert_to_bgr=False):
+ """
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+
+ Args:
+ u (np.ndarray): Input horizontal flow of shape [H,W]
+ v (np.ndarray): Input vertical flow of shape [H,W]
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u) / np.pi
+ fk = (a + 1) / 2 * (ncols - 1)
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
+ return flow_image
+
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+ """
+ Expects a two dimensional flow image of shape.
+
+ Args:
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ return flow_uv_to_colors(u, v, convert_to_bgr)
+
+
+def flow_warp(x,
+ flow,
+ interpolation='bilinear',
+ padding_mode='zeros',
+ align_corners=True):
+ """Warp an image or a feature map with optical flow.
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
+ a two-channel, denoting the width and height relative offsets.
+ Note that the values are not normalized to [-1, 1].
+ interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
+ Default: 'bilinear'.
+ padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Whether align corners. Default: True.
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ if x.size()[-2:] != flow.size()[1:3]:
+ raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
+ f'flow ({flow.size()[1:3]}) are not the same.')
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
+ grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
+ grid.requires_grad = False
+
+ grid_flow = grid + flow
+ # scale grid_flow to [-1,1]
+ grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
+ grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
+ grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
+ output = F.grid_sample(x,
+ grid_flow,
+ mode=interpolation,
+ padding_mode=padding_mode,
+ align_corners=align_corners)
+ return output
+
+
+def initial_mask_flow(mask):
+ """
+ mask 1 indicates valid pixel 0 indicates unknown pixel
+ """
+ B, T, C, H, W = mask.shape
+
+ # calculate relative position
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
+
+ grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
+ abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
+ relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
+
+ abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
+ relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
+
+ # calculate the nearest indices
+ pos_up = mask.unsqueeze(3).repeat(
+ 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
+ relative_pos_y <= H)[None, None, None]
+ nearest_indice_up = pos_up.max(dim=4)[1]
+
+ pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
+ None, None, None] * (relative_pos_y <= H)[None, None, None]
+ nearest_indice_down = (pos_down).max(dim=4)[1]
+
+ pos_left = mask.unsqueeze(4).repeat(
+ 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
+ relative_pos_x <= W)[None, None, None]
+ nearest_indice_left = (pos_left).max(dim=5)[1]
+
+ pos_right = mask.unsqueeze(4).repeat(
+ 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
+ relative_pos_x <= W)[None, None, None]
+ nearest_indice_right = (pos_right).max(dim=5)[1]
+
+ # NOTE: IMPORTANT !!! depending on how to use this offset
+ initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
+ initial_offset_down = nearest_indice_down - grid_y[None, None, None]
+
+ initial_offset_left = -(nearest_indice_left -
+ grid_x[None, None, None]).flip(4)
+ initial_offset_right = nearest_indice_right - grid_x[None, None, None]
+
+ # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
+ # initial_offset_x = nearest_indice_x - grid_x
+
+ # handle the boundary cases
+ final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
+ initial_offset_down > 0) * initial_offset_down
+ final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
+ initial_offset_up < 0) * initial_offset_up
+ final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
+ initial_offset_right > 0) * initial_offset_right
+ final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
+ initial_offset_left < 0) * initial_offset_left
+ zero_offset = torch.zeros_like(final_offset_down)
+ # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
+ out = torch.cat([
+ zero_offset, final_offset_left, zero_offset, final_offset_right,
+ final_offset_up, zero_offset, final_offset_down, zero_offset
+ ],
+ dim=2)
+
+ return out
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/spectral_norm.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/spectral_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f38c34e98c03caa28ce0b15a4083215fb7d8e9af
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/spectral_norm.py
@@ -0,0 +1,288 @@
+"""
+Spectral Normalization from https://arxiv.org/abs/1802.05957
+"""
+import torch
+from torch.nn.functional import normalize
+
+
+class SpectralNorm(object):
+ # Invariant before and after each forward call:
+ # u = normalize(W @ v)
+ # NB: At initialization, this invariant is not enforced
+
+ _version = 1
+
+ # At version 1:
+ # made `W` not a buffer,
+ # added `v` as a buffer, and
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
+
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError(
+ 'Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # permute dim to front
+ weight_mat = weight_mat.permute(
+ self.dim,
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
+ height = weight_mat.size(0)
+ return weight_mat.reshape(height, -1)
+
+ def compute_weight(self, module, do_power_iteration):
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
+ # updated in power iteration **in-place**. This is very important
+ # because in `DataParallel` forward, the vectors (being buffers) are
+ # broadcast from the parallelized module to each module replica,
+ # which is a new module object created on the fly. And each replica
+ # runs its own spectral norm power iteration. So simply assigning
+ # the updated vectors to the module this function runs on will cause
+ # the update to be lost forever. And the next time the parallelized
+ # module is replicated, the same randomly initialized vectors are
+ # broadcast and used!
+ #
+ # Therefore, to make the change propagate back, we rely on two
+ # important behaviors (also enforced via tests):
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
+ # is already on correct device; and it makes sure that the
+ # parallelized module is already on `device[0]`.
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
+ # just fill in the values.
+ # Therefore, since the same power iteration is performed on all
+ # devices, simply updating the tensors in-place will make sure that
+ # the module replica on `device[0]` will update the _u vector on the
+ # parallized module (by shared storage).
+ #
+ # However, after we update `u` and `v` in-place, we need to **clone**
+ # them before using them to normalize the weight. This is to support
+ # backproping through two forward passes, e.g., the common pattern in
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
+ # complain that variables needed to do backward for the first forward
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with torch.no_grad():
+ for _ in range(self.n_power_iterations):
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+ # are the first left and right singular vectors.
+ # This power iteration produces approximations of `u` and `v`.
+ v = normalize(torch.mv(weight_mat.t(), u),
+ dim=0,
+ eps=self.eps,
+ out=v)
+ u = normalize(torch.mv(weight_mat, v),
+ dim=0,
+ eps=self.eps,
+ out=u)
+ if self.n_power_iterations > 0:
+ # See above on why we need to clone
+ u = u.clone()
+ v = v.clone()
+
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with torch.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+ module.register_parameter(self.name,
+ torch.nn.Parameter(weight.detach()))
+
+ def __call__(self, module, inputs):
+ setattr(
+ module, self.name,
+ self.compute_weight(module, do_power_iteration=module.training))
+
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
+ # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
+ # This uses pinverse in case W^T W is not invertible.
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
+ weight_mat.t(), u.unsqueeze(1)).squeeze(1)
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError(
+ "Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with torch.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+
+ h, w = weight_mat.size()
+ # randomly initialize `u` and `v`
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
+
+ delattr(module, fn.name)
+ module.register_parameter(fn.name + "_orig", weight)
+ # We still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+ # However, we can't directly assign as it could be an nn.Parameter and
+ # gets added as a parameter. Instead, we register weight.data as a plain
+ # attribute.
+ setattr(module, fn.name, weight.data)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
+ module._register_load_state_dict_pre_hook(
+ SpectralNormLoadStateDictPreHook(fn))
+ return fn
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormLoadStateDictPreHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ # For state_dict with version None, (assuming that it has gone through at
+ # least one training forward), we have
+ #
+ # u = normalize(W_orig @ v)
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
+ #
+ # To compute `v`, we solve `W_orig @ x = u`, and let
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
+ def __call__(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ fn = self.fn
+ version = local_metadata.get('spectral_norm',
+ {}).get(fn.name + '.version', None)
+ if version is None or version < 1:
+ with torch.no_grad():
+ weight_orig = state_dict[prefix + fn.name + '_orig']
+ # weight = state_dict.pop(prefix + fn.name)
+ # sigma = (weight_orig / weight).mean()
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
+ u = state_dict[prefix + fn.name + '_u']
+ # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
+ # state_dict[prefix + fn.name + '_v'] = v
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormStateDictHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ def __call__(self, module, state_dict, prefix, local_metadata):
+ if 'spectral_norm' not in local_metadata:
+ local_metadata['spectral_norm'] = {}
+ key = self.fn.name + '.version'
+ if key in local_metadata['spectral_norm']:
+ raise RuntimeError(
+ "Unexpected key in metadata['spectral_norm']: {}".format(key))
+ local_metadata['spectral_norm'][key] = self.fn._version
+
+
+def spectral_norm(module,
+ name='weight',
+ n_power_iterations=1,
+ eps=1e-12,
+ dim=None):
+ r"""Applies spectral normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
+ power iteration method. If the dimension of the weight tensor is greater
+ than 2, it is reshaped to 2D in power iteration method to get spectral
+ norm. This is implemented via a hook that calculates spectral norm and
+ rescales weight before every :meth:`~Module.forward` call.
+
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ n_power_iterations (int, optional): number of power iterations to
+ calculate spectral norm
+ eps (float, optional): epsilon for numerical stability in
+ calculating norms
+ dim (int, optional): dimension corresponding to number of outputs,
+ the default is ``0``, except for modules that are instances of
+ ConvTranspose{1,2,3}d, when it is ``1``
+
+ Returns:
+ The original module with the spectral norm hook
+
+ Example::
+
+ >>> m = spectral_norm(nn.Linear(20, 40))
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_u.size()
+ torch.Size([40])
+
+ """
+ if dim is None:
+ if isinstance(module,
+ (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
+
+
+def remove_spectral_norm(module, name='weight'):
+ r"""Removes the spectral normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = spectral_norm(nn.Linear(40, 10))
+ >>> remove_spectral_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
+ name, module))
+
+
+def use_spectral_norm(module, use_sn=False):
+ if use_sn:
+ return spectral_norm(module)
+ return module
\ No newline at end of file
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/tfocal_transformer.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/tfocal_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..179508f490f2662331a8817b37513005e98fe4de
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/tfocal_transformer.py
@@ -0,0 +1,536 @@
+"""
+ This code is based on:
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
+ https://github.com/ruiliu-ai/FuseFormer
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
+ https://github.com/yitu-opensource/T2T-ViT
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
+ https://github.com/microsoft/Focal-Transformer
+"""
+
+import math
+from functools import reduce
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftSplit(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
+ t2t_param):
+ super(SoftSplit, self).__init__()
+ self.kernel_size = kernel_size
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(c_in, hidden)
+
+ self.f_h = int(
+ (t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
+ (t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
+ 1)
+ self.f_w = int(
+ (t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
+ (t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
+ 1)
+
+ def forward(self, x, b):
+ feat = self.t2t(x)
+ feat = feat.permute(0, 2, 1)
+ # feat shape [b*t, num_vec, ks*ks*c]
+ feat = self.embedding(feat)
+ # feat shape after embedding [b, t*num_vec, hidden]
+ feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
+ return feat
+
+
+class SoftComp(nn.Module):
+ def __init__(self, channel, hidden, output_size, kernel_size, stride,
+ padding):
+ super(SoftComp, self).__init__()
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(hidden, c_out)
+ self.t2t = torch.nn.Fold(output_size=output_size,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ h, w = output_size
+ self.bias = nn.Parameter(torch.zeros((channel, h, w),
+ dtype=torch.float32),
+ requires_grad=True)
+
+ def forward(self, x, t):
+ b_, _, _, _, c_ = x.shape
+ x = x.view(b_, -1, c_)
+ feat = self.embedding(x)
+ b, _, c = feat.size()
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
+ feat = self.t2t(feat) + self.bias[None]
+ return feat
+
+
+class FusionFeedForward(nn.Module):
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
+ super(FusionFeedForward, self).__init__()
+ # We set d_ff as a default to 1960
+ hd = 1960
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
+ assert t2t_params is not None and n_vecs is not None
+ tp = t2t_params.copy()
+ self.fold = nn.Fold(**tp)
+ del tp['output_size']
+ self.unfold = nn.Unfold(**tp)
+ self.n_vecs = n_vecs
+
+ def forward(self, x):
+ x = self.conv1(x)
+ b, n, c = x.size()
+ normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
+ 49).permute(0, 2, 1)
+ x = self.unfold(
+ self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
+ self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
+ x = self.conv2(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B*num_windows, T*window_size*window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
+ -1, T * window_size[0] * window_size[1], C)
+ return windows
+
+
+def window_partition_noreshape(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
+ return windows
+
+
+def window_reverse(windows, window_size, T, H, W):
+ """
+ Args:
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
+ window_size (tuple[int]): Window size
+ T (int): Temporal length of video
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, T, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
+ window_size[0], window_size[1], -1)
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """Temporal focal window attention
+ """
+ def __init__(self, dim, expand_size, window_size, focal_window,
+ focal_level, num_heads, qkv_bias, pool_method):
+
+ super().__init__()
+ self.dim = dim
+ self.expand_size = expand_size
+ self.window_size = window_size # Wh, Ww
+ self.pool_method = pool_method
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
+ # get mask for rolled k and rolled v
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
+ 0).flatten(0)
+ self.register_buffer("valid_ind_rolled",
+ mask_rolled.nonzero(as_tuple=False).view(-1))
+
+ if pool_method != "none" and focal_level > 1:
+ self.unfolds = nn.ModuleList()
+
+ # build relative position bias between local patch and pooled windows
+ for k in range(focal_level - 1):
+ stride = 2**k
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
+ for i in self.focal_window)
+ # define unfolding operations
+ self.unfolds += [
+ nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=tuple(i // 2 for i in kernel_size))
+ ]
+
+ # define unfolding index for focal_level > 0
+ if k > 0:
+ mask = torch.zeros(kernel_size)
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
+ self.register_buffer(
+ "valid_ind_unfold_{}".format(k),
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x_all, mask_all=None):
+ """
+ Args:
+ x: input features with shape of (B, T, Wh, Ww, C)
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
+
+ output: (nW*B, Wh*Ww, C)
+ """
+ x = x_all[0]
+
+ B, T, nH, nW, C = x.shape
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
+
+ # partition q map
+ (q_windows, k_windows, v_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
+ contiguous().view(-1, self.num_heads, T * self.window_size[
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
+ (k_tl, v_tl) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_tr, v_tr) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_bl, v_bl) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_br, v_br) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
+ k_rolled = torch.cat(
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+ v_rolled = torch.cat(
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+
+ # mask out tokens in current window
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
+ temp_N = k_rolled.shape[3]
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
+ else:
+ k_rolled = k_windows
+ v_rolled = v_windows
+
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+ # k_rolled.shape : [16, 4, 5, 165, 128]
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ k_pooled = []
+ v_pooled = []
+ for k in range(self.focal_level - 1):
+ stride = 2**k
+ x_window_pooled = x_all[k + 1].permute(
+ 0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C
+
+ nWh, nWw = x_window_pooled.shape[2:4]
+
+ # generate mask for pooled windows
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
+ view(nWh*nWw // stride // stride, -1, 1)
+
+ if k > 0:
+ valid_ind_unfold_k = getattr(
+ self, "valid_ind_unfold_{}".format(k))
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ x_window_masks = x_window_masks.masked_fill(
+ x_window_masks == 0,
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
+ mask_all[k + 1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
+ 3).view(3, -1, C, nWh,
+ nWw).contiguous()
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
+ 2] # B*T, C, nWh, nWw
+ # k_pooled_k shape: [5, 512, 4, 4]
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds[k](t).view(
+ B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
+ view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
+ (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
+ )
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
+
+ # select valid unfolding index
+ if k > 0:
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: t[:, :, :, valid_ind_unfold_k],
+ (k_pooled_k, v_pooled_k))
+
+ k_pooled_k = k_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+ v_pooled_k = v_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
+ else:
+ k_all = k_rolled
+ v_all = v_rolled
+
+ N = k_all.shape[-2]
+ q_windows = q_windows * self.scale
+ attn = (
+ q_windows @ k_all.transpose(-2, -1)
+ ) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
+ # T * 45
+ window_area = T * self.window_size[0] * self.window_size[1]
+ # T * 165
+ window_area_rolled = k_rolled.shape[2]
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ offset = window_area_rolled
+ for k in range(self.focal_level - 1):
+ # add attentional mask
+ # mask_all[1] shape [1, 16, T * 45]
+
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
+
+ if mask_all[k + 1] is not None:
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
+ mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
+
+ offset += T * bias[0] * bias[1]
+
+ if mask_all[0] is not None:
+ nW = mask_all[0].shape[0]
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
+ window_area, N)
+ attn[:, :, :, :, :
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
+ None, :, None, :, :]
+ attn = attn.view(-1, self.num_heads, window_area, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
+ C)
+ x = self.proj(x)
+ return x
+
+
+class TemporalFocalTransformerBlock(nn.Module):
+ r""" Temporal Focal Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): The number level of focal window.
+ focal_window (int): Window size of each focal window.
+ n_vecs (int): Required for F3N.
+ t2t_params (int): T2T parameters for F3N.
+ """
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(5, 9),
+ mlp_ratio=4.,
+ qkv_bias=True,
+ pool_method="fc",
+ focal_level=2,
+ focal_window=(5, 9),
+ norm_layer=nn.LayerNorm,
+ n_vecs=None,
+ t2t_params=None):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
+ self.mlp_ratio = mlp_ratio
+ self.pool_method = pool_method
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ self.window_size_glo = self.window_size
+
+ self.pool_layers = nn.ModuleList()
+ if self.pool_method != "none":
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ self.pool_layers.append(
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
+ self.pool_layers[-1].weight.data.fill_(
+ 1. / (window_size_glo[0] * window_size_glo[1]))
+ self.pool_layers[-1].bias.data.fill_(0)
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = WindowAttention(dim,
+ expand_size=self.expand_size,
+ window_size=self.window_size,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ pool_method=pool_method)
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
+
+ def forward(self, x):
+ B, T, H, W, C = x.shape
+
+ shortcut = x
+ x = self.norm1(x)
+
+ shifted_x = x
+
+ x_windows_all = [shifted_x]
+ x_window_masks_all = [None]
+
+ # partition windows tuple(i // 2 for i in window_size)
+ if self.focal_level > 1 and self.pool_method != "none":
+ # if we add coarser granularity and the pool method is not none
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
+ H_pool = pooled_h * window_size_glo[0]
+ W_pool = pooled_w * window_size_glo[1]
+
+ x_level_k = shifted_x
+ # trim or pad shifted_x depending on the required size
+ if H > H_pool:
+ trim_t = (H - H_pool) // 2
+ trim_b = H - H_pool - trim_t
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
+ elif H < H_pool:
+ pad_t = (H_pool - H) // 2
+ pad_b = H_pool - H - pad_t
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
+
+ if W > W_pool:
+ trim_l = (W - W_pool) // 2
+ trim_r = W - W_pool - trim_l
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
+ elif W < W_pool:
+ pad_l = (W_pool - W) // 2
+ pad_r = W_pool - W - pad_l
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
+
+ x_windows_noreshape = window_partition_noreshape(
+ x_level_k.contiguous(), window_size_glo
+ ) # B, nw, nw, T, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ x_windows_noreshape = x_windows_noreshape.view(
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
+ x_windows_pooled = self.pool_layers[k](
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
+
+ x_windows_all += [x_windows_pooled]
+ x_window_masks_all += [None]
+
+ attn_windows = self.attn(
+ x_windows_all,
+ mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
+ self.window_size[1], C)
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
+ W) # B T H' W' C
+
+ # FFN
+ x = shortcut + shifted_x
+ y = self.norm2(x)
+ x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)
+
+ return x
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/tfocal_transformer_hq.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/tfocal_transformer_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a24dfa799533ff96bfb94b01ad8593f45bb590f
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/model/modules/tfocal_transformer_hq.py
@@ -0,0 +1,565 @@
+"""
+ This code is based on:
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
+ https://github.com/ruiliu-ai/FuseFormer
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
+ https://github.com/yitu-opensource/T2T-ViT
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
+ https://github.com/microsoft/Focal-Transformer
+"""
+
+import math
+from functools import reduce
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftSplit(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
+ t2t_param):
+ super(SoftSplit, self).__init__()
+ self.kernel_size = kernel_size
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(c_in, hidden)
+
+ self.t2t_param = t2t_param
+
+ def forward(self, x, b, output_size):
+ f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
+ (self.t2t_param['kernel_size'][0] - 1) - 1) /
+ self.t2t_param['stride'][0] + 1)
+ f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
+ (self.t2t_param['kernel_size'][1] - 1) - 1) /
+ self.t2t_param['stride'][1] + 1)
+
+ feat = self.t2t(x)
+ feat = feat.permute(0, 2, 1)
+ # feat shape [b*t, num_vec, ks*ks*c]
+ feat = self.embedding(feat)
+ # feat shape after embedding [b, t*num_vec, hidden]
+ feat = feat.view(b, -1, f_h, f_w, feat.size(2))
+ return feat
+
+
+class SoftComp(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
+ super(SoftComp, self).__init__()
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(hidden, c_out)
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.bias_conv = nn.Conv2d(channel,
+ channel,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # TODO upsample conv
+ # self.bias_conv = nn.Conv2d()
+ # self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)
+
+ def forward(self, x, t, output_size):
+ b_, _, _, _, c_ = x.shape
+ x = x.view(b_, -1, c_)
+ feat = self.embedding(x)
+ b, _, c = feat.size()
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
+ feat = F.fold(feat,
+ output_size=output_size,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding)
+ feat = self.bias_conv(feat)
+ return feat
+
+
+class FusionFeedForward(nn.Module):
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
+ super(FusionFeedForward, self).__init__()
+ # We set d_ff as a default to 1960
+ hd = 1960
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
+ assert t2t_params is not None and n_vecs is not None
+ self.t2t_params = t2t_params
+
+ def forward(self, x, output_size):
+ n_vecs = 1
+ for i, d in enumerate(self.t2t_params['kernel_size']):
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
+
+ x = self.conv1(x)
+ b, n, c = x.size()
+ normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
+ normalizer = F.fold(normalizer,
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.unfold(x / normalizer,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride']).permute(
+ 0, 2, 1).contiguous().view(b, n, c)
+ x = self.conv2(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B*num_windows, T*window_size*window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
+ -1, T * window_size[0] * window_size[1], C)
+ return windows
+
+
+def window_partition_noreshape(x, window_size):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
+ window_size[1], C)
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
+ return windows
+
+
+def window_reverse(windows, window_size, T, H, W):
+ """
+ Args:
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
+ window_size (tuple[int]): Window size
+ T (int): Temporal length of video
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, T, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
+ window_size[0], window_size[1], -1)
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """Temporal focal window attention
+ """
+ def __init__(self, dim, expand_size, window_size, focal_window,
+ focal_level, num_heads, qkv_bias, pool_method):
+
+ super().__init__()
+ self.dim = dim
+ self.expand_size = expand_size
+ self.window_size = window_size # Wh, Ww
+ self.pool_method = pool_method
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
+ # get mask for rolled k and rolled v
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
+ 0).flatten(0)
+ self.register_buffer("valid_ind_rolled",
+ mask_rolled.nonzero(as_tuple=False).view(-1))
+
+ if pool_method != "none" and focal_level > 1:
+ self.unfolds = nn.ModuleList()
+
+ # build relative position bias between local patch and pooled windows
+ for k in range(focal_level - 1):
+ stride = 2**k
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
+ for i in self.focal_window)
+ # define unfolding operations
+ self.unfolds += [
+ nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=tuple(i // 2 for i in kernel_size))
+ ]
+
+ # define unfolding index for focal_level > 0
+ if k > 0:
+ mask = torch.zeros(kernel_size)
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
+ self.register_buffer(
+ "valid_ind_unfold_{}".format(k),
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x_all, mask_all=None):
+ """
+ Args:
+ x: input features with shape of (B, T, Wh, Ww, C)
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
+
+ output: (nW*B, Wh*Ww, C)
+ """
+ x = x_all[0]
+
+ B, T, nH, nW, C = x.shape
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
+
+ # partition q map
+ (q_windows, k_windows, v_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
+ contiguous().view(-1, self.num_heads, T * self.window_size[
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
+ (k_tl, v_tl) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_tr, v_tr) = map(
+ lambda t: torch.roll(t,
+ shifts=(-self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_bl, v_bl) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], -self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+ (k_br, v_br) = map(
+ lambda t: torch.roll(t,
+ shifts=(self.expand_size[0], self.
+ expand_size[1]),
+ dims=(2, 3)), (k, v))
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda t: window_partition(t, self.window_size).view(
+ -1, T, self.window_size[0] * self.window_size[1], self.
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
+ k_rolled = torch.cat(
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+ v_rolled = torch.cat(
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
+ 2).permute(0, 3, 1, 2, 4).contiguous()
+
+ # mask out tokens in current window
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
+ temp_N = k_rolled.shape[3]
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
+ C // self.num_heads)
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
+ else:
+ k_rolled = k_windows
+ v_rolled = v_windows
+
+ # q(k/v)_windows shape : [16, 4, 225, 128]
+ # k_rolled.shape : [16, 4, 5, 165, 128]
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ k_pooled = []
+ v_pooled = []
+ for k in range(self.focal_level - 1):
+ stride = 2**k
+ # B, T, nWh, nWw, C
+ x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
+ 4).contiguous()
+
+ nWh, nWw = x_window_pooled.shape[2:4]
+
+ # generate mask for pooled windows
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
+ view(nWh*nWw // stride // stride, -1, 1)
+
+ if k > 0:
+ valid_ind_unfold_k = getattr(
+ self, "valid_ind_unfold_{}".format(k))
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ x_window_masks = x_window_masks.masked_fill(
+ x_window_masks == 0,
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
+ mask_all[k + 1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
+ 3).view(3, -1, C, nWh,
+ nWw).contiguous()
+ # B*T, C, nWh, nWw
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
+ # k_pooled_k shape: [5, 512, 4, 4]
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds[k]
+ (t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
+ unfolds[k].kernel_size[1], -1)
+ .permute(0, 5, 1, 3, 4, 2).contiguous().view(
+ -1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
+ k].kernel_size[1], self.num_heads, C // self.
+ num_heads).permute(0, 3, 1, 2, 4).contiguous(),
+ # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
+ (k_pooled_k, v_pooled_k))
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
+
+ # select valid unfolding index
+ if k > 0:
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: t[:, :, :, valid_ind_unfold_k],
+ (k_pooled_k, v_pooled_k))
+
+ k_pooled_k = k_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+ v_pooled_k = v_pooled_k.view(
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
+ else:
+ k_all = k_rolled
+ v_all = v_rolled
+
+ N = k_all.shape[-2]
+ q_windows = q_windows * self.scale
+ # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
+ attn = (q_windows @ k_all.transpose(-2, -1))
+ # T * 45
+ window_area = T * self.window_size[0] * self.window_size[1]
+ # T * 165
+ window_area_rolled = k_rolled.shape[2]
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ offset = window_area_rolled
+ for k in range(self.focal_level - 1):
+ # add attentional mask
+ # mask_all[1] shape [1, 16, T * 45]
+
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
+
+ if mask_all[k + 1] is not None:
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
+ mask_all[k+1][:, :, None, None, :].repeat(
+ attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
+
+ offset += T * bias[0] * bias[1]
+
+ if mask_all[0] is not None:
+ nW = mask_all[0].shape[0]
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
+ window_area, N)
+ attn[:, :, :, :, :
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
+ None, :, None, :, :]
+ attn = attn.view(-1, self.num_heads, window_area, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
+ C)
+ x = self.proj(x)
+ return x
+
+
+class TemporalFocalTransformerBlock(nn.Module):
+ r""" Temporal Focal Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): The number level of focal window.
+ focal_window (int): Window size of each focal window.
+ n_vecs (int): Required for F3N.
+ t2t_params (int): T2T parameters for F3N.
+ """
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(5, 9),
+ mlp_ratio=4.,
+ qkv_bias=True,
+ pool_method="fc",
+ focal_level=2,
+ focal_window=(5, 9),
+ norm_layer=nn.LayerNorm,
+ n_vecs=None,
+ t2t_params=None):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
+ self.mlp_ratio = mlp_ratio
+ self.pool_method = pool_method
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ self.window_size_glo = self.window_size
+
+ self.pool_layers = nn.ModuleList()
+ if self.pool_method != "none":
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ self.pool_layers.append(
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
+ self.pool_layers[-1].weight.data.fill_(
+ 1. / (window_size_glo[0] * window_size_glo[1]))
+ self.pool_layers[-1].bias.data.fill_(0)
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = WindowAttention(dim,
+ expand_size=self.expand_size,
+ window_size=self.window_size,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ pool_method=pool_method)
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
+
+ def forward(self, x):
+ output_size = x[1]
+ x = x[0]
+
+ B, T, H, W, C = x.shape
+
+ shortcut = x
+ x = self.norm1(x)
+
+ shifted_x = x
+
+ x_windows_all = [shifted_x]
+ x_window_masks_all = [None]
+
+ # partition windows tuple(i // 2 for i in window_size)
+ if self.focal_level > 1 and self.pool_method != "none":
+ # if we add coarser granularity and the pool method is not none
+ for k in range(self.focal_level - 1):
+ window_size_glo = tuple(
+ math.floor(i / (2**k)) for i in self.window_size_glo)
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
+ H_pool = pooled_h * window_size_glo[0]
+ W_pool = pooled_w * window_size_glo[1]
+
+ x_level_k = shifted_x
+ # trim or pad shifted_x depending on the required size
+ if H > H_pool:
+ trim_t = (H - H_pool) // 2
+ trim_b = H - H_pool - trim_t
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
+ elif H < H_pool:
+ pad_t = (H_pool - H) // 2
+ pad_b = H_pool - H - pad_t
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
+
+ if W > W_pool:
+ trim_l = (W - W_pool) // 2
+ trim_r = W - W_pool - trim_l
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
+ elif W < W_pool:
+ pad_l = (W_pool - W) // 2
+ pad_r = W_pool - W - pad_l
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
+
+ x_windows_noreshape = window_partition_noreshape(
+ x_level_k.contiguous(), window_size_glo
+ ) # B, nw, nw, T, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ x_windows_noreshape = x_windows_noreshape.view(
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
+ x_windows_pooled = self.pool_layers[k](
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
+
+ x_windows_all += [x_windows_pooled]
+ x_window_masks_all += [None]
+
+ # nW*B, T*window_size*window_size, C
+ attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
+ self.window_size[1], C)
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
+ W) # B T H' W' C
+
+ # FFN
+ x = shortcut + shifted_x
+ y = self.norm2(x)
+ x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
+ B, T, H, W, C)
+
+ return x, output_size
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/release_model/README.md b/phantom/submodules/phantom-E2FGVI/E2FGVI/release_model/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b2ae3bcc2c4e717adca2d375352b88de88156a6
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/release_model/README.md
@@ -0,0 +1,11 @@
+Place the downloaded model here.
+
+:link: **Download Links:** [[Google Drive](https://drive.google.com/file/d/1tNJMTJ2gmWdIXJoHVi5-H504uImUiJW9/view?usp=sharing)] [[Baidu Disk](https://pan.baidu.com/s/1qXAErbilY_n_Fh9KB8UF7w?pwd=lsjw)]
+
+The directory structure will be arranged as:
+```
+release_model
+ |- E2FGVI-CVPR22.pth
+ |- i3d_rgb_imagenet.pt (for evaluating VFID metric)
+ |- README.md
+```
\ No newline at end of file
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/test.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..448f10c3d92843f66278b5cda867bdad400ca2d3
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/test.py
@@ -0,0 +1,224 @@
+# -*- coding: utf-8 -*-
+import cv2
+from PIL import Image
+import numpy as np
+import importlib
+import os
+import argparse
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+from matplotlib import animation
+import torch
+
+from core.utils import to_tensors
+
+parser = argparse.ArgumentParser(description="E2FGVI")
+parser.add_argument("-v", "--video", type=str, required=True)
+parser.add_argument("-c", "--ckpt", type=str, required=True)
+parser.add_argument("-m", "--mask", type=str, required=True)
+parser.add_argument("--model", type=str, choices=['e2fgvi', 'e2fgvi_hq'])
+parser.add_argument("--step", type=int, default=10)
+parser.add_argument("--num_ref", type=int, default=-1)
+parser.add_argument("--neighbor_stride", type=int, default=5)
+parser.add_argument("--savefps", type=int, default=24)
+
+# args for e2fgvi_hq (which can handle videos with arbitrary resolution)
+parser.add_argument("--set_size", action='store_true', default=False)
+parser.add_argument("--width", type=int)
+parser.add_argument("--height", type=int)
+
+args = parser.parse_args()
+
+ref_length = args.step # ref_step
+num_ref = args.num_ref
+neighbor_stride = args.neighbor_stride
+default_fps = args.savefps
+
+
+# sample reference frames from the whole video
+def get_ref_index(f, neighbor_ids, length):
+ ref_index = []
+ if num_ref == -1:
+ for i in range(0, length, ref_length):
+ if i not in neighbor_ids:
+ ref_index.append(i)
+ else:
+ start_idx = max(0, f - ref_length * (num_ref // 2))
+ end_idx = min(length, f + ref_length * (num_ref // 2))
+ for i in range(start_idx, end_idx + 1, ref_length):
+ if i not in neighbor_ids:
+ if len(ref_index) > num_ref:
+ break
+ ref_index.append(i)
+ return ref_index
+
+
+# read frame-wise masks
+def read_mask(mpath, size):
+ masks = []
+ mnames = os.listdir(mpath)
+ mnames.sort()
+ for mp in mnames:
+ m = Image.open(os.path.join(mpath, mp))
+ m = m.resize(size, Image.NEAREST)
+ m = np.array(m.convert('L'))
+ m = np.array(m > 0).astype(np.uint8)
+ m = cv2.dilate(m,
+ cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
+ iterations=4)
+ masks.append(Image.fromarray(m * 255))
+ return masks
+
+
+# read frames from video
+def read_frame_from_videos(args):
+ vname = args.video
+ frames = []
+ if args.use_mp4:
+ vidcap = cv2.VideoCapture(vname)
+ success, image = vidcap.read()
+ count = 0
+ while success:
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+ frames.append(image)
+ success, image = vidcap.read()
+ count += 1
+ else:
+ lst = os.listdir(vname)
+ lst.sort()
+ fr_lst = [vname + '/' + name for name in lst]
+ for fr in fr_lst:
+ image = cv2.imread(fr)
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+ frames.append(image)
+ return frames
+
+
+# resize frames
+def resize_frames(frames, size=None):
+ if size is not None:
+ frames = [f.resize(size) for f in frames]
+ else:
+ size = frames[0].size
+ return frames, size
+
+
+def main_worker():
+ # set up models
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if args.model == "e2fgvi":
+ size = (432, 240)
+ elif args.set_size:
+ size = (args.width, args.height)
+ else:
+ size = None
+
+ net = importlib.import_module('model.' + args.model)
+ model = net.InpaintGenerator().to(device)
+ data = torch.load(args.ckpt, map_location=device)
+ model.load_state_dict(data)
+ print(f'Loading model from: {args.ckpt}')
+ model.eval()
+
+ # prepare datset
+ args.use_mp4 = True if args.video.endswith('.mp4') else False
+ print(
+ f'Loading videos and masks from: {args.video} | INPUT MP4 format: {args.use_mp4}'
+ )
+ frames = read_frame_from_videos(args)
+ frames, size = resize_frames(frames, size)
+ h, w = size[1], size[0]
+ video_length = len(frames)
+ imgs = to_tensors()(frames).unsqueeze(0) * 2 - 1
+ frames = [np.array(f).astype(np.uint8) for f in frames]
+
+ masks = read_mask(args.mask, size)
+ binary_masks = [
+ np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) for m in masks
+ ]
+ masks = to_tensors()(masks).unsqueeze(0)
+ imgs, masks = imgs.to(device), masks.to(device)
+ comp_frames = [None] * video_length
+
+ # completing holes by e2fgvi
+ print(f'Start test...')
+ for f in tqdm(range(0, video_length, neighbor_stride)):
+ neighbor_ids = [
+ i for i in range(max(0, f - neighbor_stride),
+ min(video_length, f + neighbor_stride + 1))
+ ]
+ ref_ids = get_ref_index(f, neighbor_ids, video_length)
+ selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
+ selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
+ with torch.no_grad():
+ masked_imgs = selected_imgs * (1 - selected_masks)
+ mod_size_h = 60
+ mod_size_w = 108
+ h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
+ w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
+ masked_imgs = torch.cat(
+ [masked_imgs, torch.flip(masked_imgs, [3])],
+ 3)[:, :, :, :h + h_pad, :]
+ masked_imgs = torch.cat(
+ [masked_imgs, torch.flip(masked_imgs, [4])],
+ 4)[:, :, :, :, :w + w_pad]
+ pred_imgs, _ = model(masked_imgs, len(neighbor_ids))
+ pred_imgs = pred_imgs[:, :, :h, :w]
+ pred_imgs = (pred_imgs + 1) / 2
+ pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
+ for i in range(len(neighbor_ids)):
+ idx = neighbor_ids[i]
+ img = np.array(pred_imgs[i]).astype(
+ np.uint8) * binary_masks[idx] + frames[idx] * (
+ 1 - binary_masks[idx])
+ if comp_frames[idx] is None:
+ comp_frames[idx] = img
+ else:
+ comp_frames[idx] = comp_frames[idx].astype(
+ np.float32) * 0.5 + img.astype(np.float32) * 0.5
+
+ # saving videos
+ print('Saving videos...')
+ save_dir_name = 'results'
+ ext_name = '_results.mp4'
+ save_base_name = args.video.split('/')[-1]
+ save_name = save_base_name.replace(
+ '.mp4', ext_name) if args.use_mp4 else save_base_name + ext_name
+ if not os.path.exists(save_dir_name):
+ os.makedirs(save_dir_name)
+ save_path = os.path.join(save_dir_name, save_name)
+ writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"),
+ default_fps, size)
+ for f in range(video_length):
+ comp = comp_frames[f].astype(np.uint8)
+ writer.write(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
+ writer.release()
+ print(f'Finish test! The result video is saved in: {save_path}.')
+
+ # show results
+ print('Let us enjoy the result!')
+ fig = plt.figure('Let us enjoy the result')
+ ax1 = fig.add_subplot(1, 2, 1)
+ ax1.axis('off')
+ ax1.set_title('Original Video')
+ ax2 = fig.add_subplot(1, 2, 2)
+ ax2.axis('off')
+ ax2.set_title('Our Result')
+ imdata1 = ax1.imshow(frames[0])
+ imdata2 = ax2.imshow(comp_frames[0].astype(np.uint8))
+
+ def update(idx):
+ imdata1.set_data(frames[idx])
+ imdata2.set_data(comp_frames[idx].astype(np.uint8))
+
+ fig.tight_layout()
+ anim = animation.FuncAnimation(fig,
+ update,
+ frames=len(frames),
+ interval=50)
+ plt.show()
+
+
+if __name__ == '__main__':
+ main_worker()
diff --git a/phantom/submodules/phantom-E2FGVI/E2FGVI/train.py b/phantom/submodules/phantom-E2FGVI/E2FGVI/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1770db607ae1eb2af3f5a2ce3cd96fa629602d78
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/E2FGVI/train.py
@@ -0,0 +1,89 @@
+import os
+import json
+import argparse
+from shutil import copyfile
+
+import torch
+import torch.multiprocessing as mp
+
+from core.trainer import Trainer
+from core.dist import (
+ get_world_size,
+ get_local_rank,
+ get_global_rank,
+ get_master_ip,
+)
+
+parser = argparse.ArgumentParser(description='E2FGVI')
+parser.add_argument('-c',
+ '--config',
+ default='configs/train_e2fgvi.json',
+ type=str)
+parser.add_argument('-p', '--port', default='23455', type=str)
+args = parser.parse_args()
+
+
+def main_worker(rank, config):
+ if 'local_rank' not in config:
+ config['local_rank'] = config['global_rank'] = rank
+ if config['distributed']:
+ torch.cuda.set_device(int(config['local_rank']))
+ torch.distributed.init_process_group(backend='nccl',
+ init_method=config['init_method'],
+ world_size=config['world_size'],
+ rank=config['global_rank'],
+ group_name='mtorch')
+ print('using GPU {}-{} for training'.format(int(config['global_rank']),
+ int(config['local_rank'])))
+
+ config['save_dir'] = os.path.join(
+ config['save_dir'],
+ '{}_{}'.format(config['model']['net'],
+ os.path.basename(args.config).split('.')[0]))
+
+ config['save_metric_dir'] = os.path.join(
+ './scores',
+ '{}_{}'.format(config['model']['net'],
+ os.path.basename(args.config).split('.')[0]))
+
+ if torch.cuda.is_available():
+ config['device'] = torch.device("cuda:{}".format(config['local_rank']))
+ else:
+ config['device'] = 'cpu'
+
+ if (not config['distributed']) or config['global_rank'] == 0:
+ os.makedirs(config['save_dir'], exist_ok=True)
+ os.makedirs(config['save_metric_dir'], exist_ok=True)
+ config_path = os.path.join(config['save_dir'],
+ args.config.split('/')[-1])
+ if not os.path.isfile(config_path):
+ copyfile(args.config, config_path)
+ print('[**] create folder {}'.format(config['save_dir']))
+
+ trainer = Trainer(config)
+ trainer.train()
+
+
+if __name__ == "__main__":
+
+ torch.backends.cudnn.benchmark = True
+
+ mp.set_sharing_strategy('file_system')
+
+ # loading configs
+ config = json.load(open(args.config))
+
+ # setting distributed configurations
+ config['world_size'] = get_world_size()
+ config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
+ config['distributed'] = True if config['world_size'] > 1 else False
+ print(config['world_size'])
+ # setup distributed parallel training environments
+ if get_master_ip() == "127.0.0.1":
+ # manually launch distributed processes
+ mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))
+ else:
+ # multiple processes have been launched by openmpi
+ config['local_rank'] = get_local_rank()
+ config['global_rank'] = get_global_rank()
+ main_worker(-1, config)
diff --git a/phantom/submodules/phantom-E2FGVI/LICENSE b/phantom/submodules/phantom-E2FGVI/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..17bc97bff80068baf08757e6e2ffd03a2c1208d4
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/LICENSE
@@ -0,0 +1,163 @@
+## creative commons
+
+# Attribution-NonCommercial 4.0 International
+
+Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
+
+### Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
+
+* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
+
+* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
+
+## Creative Commons Attribution-NonCommercial 4.0 International Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
+
+### Section 1 – Definitions.
+
+a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
+
+b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
+
+c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
+
+d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
+
+e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
+
+f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
+
+g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
+
+h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
+
+i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
+
+j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
+
+k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
+
+l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
+
+### Section 2 – Scope.
+
+a. ___License grant.___
+
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
+
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
+
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
+
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
+
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
+
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
+
+ 5. __Downstream recipients.__
+
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
+
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
+
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
+
+b. ___Other rights.___
+
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this Public License.
+
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
+
+### Section 3 – License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the following conditions.
+
+a. ___Attribution.___
+
+ 1. If You Share the Licensed Material (including in modified form), You must:
+
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
+
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
+
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
+
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
+
+### Section 4 – Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
+
+a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
+
+b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
+
+c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
+
+### Section 5 – Disclaimer of Warranties and Limitation of Liability.
+
+a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
+
+b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
+
+c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
+
+### Section 6 – Term and Termination.
+
+a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
+
+b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
+
+c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
+
+d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
+
+### Section 7 – Other Terms and Conditions.
+
+a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
+
+b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
+
+### Section 8 – Interpretation.
+
+a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
+
+b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
+
+c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
+
+d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
+
+> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
+>
+> Creative Commons may be contacted at creativecommons.org
+
+Copyright (c) 2022 MCG-NKU
diff --git a/phantom/submodules/phantom-E2FGVI/README.md b/phantom/submodules/phantom-E2FGVI/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..70ed3db1788508ea2887f78d48b629f85d5a4d8a
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/README.md
@@ -0,0 +1,297 @@
+# E2FGVI (CVPR 2022)
+[](https://paperswithcode.com/sota/video-inpainting-on-davis?p=towards-an-end-to-end-framework-for-flow)
+[](https://paperswithcode.com/sota/video-inpainting-on-youtube-vos?p=towards-an-end-to-end-framework-for-flow)
+
+
+
+
+English | [简体中文](README_zh-CN.md)
+
+This repository contains the official implementation of the following paper:
+> **Towards An End-to-End Framework for Flow-Guided Video Inpainting**
+> Zhen Li#, Cheng-Ze Lu#, Jianhua Qin, Chun-Le Guo*, Ming-Ming Cheng
+> IEEE/CVF Conference on Computer Vision and Pattern Recognition (**CVPR**), 2022
+
+[[Paper](https://arxiv.org/abs/2204.02663)]
+[[Demo Video (Youtube)](https://www.youtube.com/watch?v=N--qC3T2wc4)]
+[[演示视频 (B站)](https://www.bilibili.com/video/BV1Ta411n7eH?spm_id_from=333.999.0.0)]
+[[MindSpore Implementation](https://github.com/Dragoniss/minspore-phase2-E2FGVI)]
+[Project Page (TBD)]
+[Poster (TBD)]
+
+You can try our colab demo here: [](https://colab.research.google.com/drive/12rwY2gtG8jVWlNx9pjmmM8uGmh5ue18G?usp=sharing)
+
+## :star: News
+- *2022.05.15:* We release E2FGVI-HQ, which can handle videos with **arbitrary resolution**. This model could generalize well to much higher resolutions, while it only used 432x240 videos for training. Besides, it performs **better** than our original model on both PSNR and SSIM metrics.
+:link: Download links: [[Google Drive](https://drive.google.com/file/d/10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3/view?usp=sharing)] [[Baidu Disk](https://pan.baidu.com/s/1jfm1oFU1eIy-IRfuHP8YXw?pwd=ssb3)] :movie_camera: Demo video: [[Youtube](https://www.youtube.com/watch?v=N--qC3T2wc4)] [[B站](https://www.bilibili.com/video/BV1Ta411n7eH?spm_id_from=333.999.0.0)]
+
+- *2022.04.06:* Our code is publicly available.
+## Demo
+
+
+
+### More examples (click for details):
+
+
+
+
+
+
+ Coco (click me)
+
+
+
+
+
+
+
+ Tennis
+
+
+
+
+
+
+
+
+
+ Space
+
+
+
+
+
+
+
+ Motocross
+
+
+
+
+
+
+
+## Overview
+
+
+### :rocket: Highlights:
+- **SOTA performance**: The proposed E2FGVI achieves significant improvements on all quantitative metrics in comparison with SOTA methods.
+- **Highly effiency**: Our method processes 432 × 240 videos at 0.12 seconds per frame on a Titan XP GPU, which is nearly 15× faster than previous flow-based methods. Besides, our method has the lowest FLOPs among all compared SOTA
+methods.
+
+## Work in Progress
+- [ ] Update website page
+- [ ] Hugging Face demo
+- [ ] Efficient inference
+
+## Dependencies and Installation
+
+1. Clone Repo
+
+ ```bash
+ git clone https://github.com/MCG-NKU/E2FGVI.git
+ ```
+
+2. Create Conda Environment and Install Dependencies
+
+ ```bash
+ conda env create -f environment.yml
+ conda activate e2fgvi
+ ```
+ - Python >= 3.7
+ - PyTorch >= 1.5
+ - CUDA >= 9.2
+ - [mmcv-full](https://github.com/open-mmlab/mmcv#installation) (following the pipeline to install)
+
+ If the `environment.yml` file does not work for you, please follow [this issue](https://github.com/MCG-NKU/E2FGVI/issues/3) to solve the problem.
+
+## Get Started
+### Prepare pretrained models
+Before performing the following steps, please download our pretrained model first.
+
+
+
+Then, unzip the file and place the models to `release_model` directory.
+
+The directory structure will be arranged as:
+```
+release_model
+ |- E2FGVI-CVPR22.pth
+ |- E2FGVI-HQ-CVPR22.pth
+ |- i3d_rgb_imagenet.pt (for evaluating VFID metric)
+ |- README.md
+```
+
+### Quick test
+We provide two examples in the [`examples`](./examples) directory.
+
+Run the following command to enjoy them:
+```shell
+# The first example (using split video frames)
+python test.py --model e2fgvi (or e2fgvi_hq) --video examples/tennis --mask examples/tennis_mask --ckpt release_model/E2FGVI-CVPR22.pth (or release_model/E2FGVI-HQ-CVPR22.pth)
+# The second example (using mp4 format video)
+python test.py --model e2fgvi (or e2fgvi_hq) --video examples/schoolgirls.mp4 --mask examples/schoolgirls_mask --ckpt release_model/E2FGVI-CVPR22.pth (or release_model/E2FGVI-HQ-CVPR22.pth)
+```
+The inpainting video will be saved in the `results` directory.
+Please prepare your own **mp4 video** (or **split frames**) and **frame-wise masks** if you want to test more cases.
+
+*Note:* E2FGVI always rescales the input video to a fixed resolution (432x240), while E2FGVI-HQ does not change the resolution of the input video. If you want to custom the output resolution, please use the `--set_size` flag and set the values of `--width` and `--height`.
+
+Example:
+```shell
+# Using this command to output a 720p video
+python test.py --model e2fgvi_hq --video --mask --ckpt release_model/E2FGVI-HQ-CVPR22.pth --set_size --width 1280 --height 720
+```
+
+
+### Prepare dataset for training and evaluation
+
+
+
+
Dataset
+
YouTube-VOS
+
DAVIS
+
+
+
+
+
Details
+
For training (3,471) and evaluation (508)
+
For evaluation (50 in 90)
+
+
Images
+
[Official Link] (Download train and test all frames)
+
+The training and test split files are provided in `datasets/`.
+
+For each dataset, you should place `JPEGImages` to `datasets/`.
+
+Then, run `sh datasets/zip_dir.sh` (**Note**: please edit the folder path accordingly) for compressing each video in `datasets//JPEGImages`.
+
+Unzip downloaded mask files to `datasets`.
+
+The `datasets` directory structure will be arranged as: (**Note**: please check it carefully)
+```
+datasets
+ |- davis
+ |- JPEGImages
+ |- .zip
+ |- .zip
+ |- test_masks
+ |-
+ |- 00000.png
+ |- 00001.png
+ |- train.json
+ |- test.json
+ |- youtube-vos
+ |- JPEGImages
+ |- .zip
+ |- .zip
+ |- test_masks
+ |-
+ |- 00000.png
+ |- 00001.png
+ |- train.json
+ |- test.json
+ |- zip_file.sh
+```
+### Evaluation
+Run one of the following commands for evaluation:
+```shell
+ # For evaluating E2FGVI model
+ python evaluate.py --model e2fgvi --dataset --data_root datasets/ --ckpt release_model/E2FGVI-CVPR22.pth
+ # For evaluating E2FGVI-HQ model
+ python evaluate.py --model e2fgvi_hq --dataset --data_root datasets/ --ckpt release_model/E2FGVI-HQ-CVPR22.pth
+
+```
+You will get scores as paper reported if you evaluate E2FGVI.
+The scores of E2FGVI-HQ can be found in [[Prepare pretrained models](https://github.com/MCG-NKU/E2FGVI#prepare-pretrained-models)].
+
+The scores will also be saved in the `results/_` directory.
+
+Please `--save_results` for further [evaluating temporal warping error](https://github.com/phoenix104104/fast_blind_video_consistency#evaluation).
+
+### Training
+Our training configures are provided in [`train_e2fgvi.json`](./configs/train_e2fgvi.json) (for E2FGVI) and [`train_e2fgvi_hq.json`](./configs/train_e2fgvi_hq.json) (for E2FGVI-HQ).
+
+Run one of the following commands for training:
+```shell
+ # For training E2FGVI
+ python train.py -c configs/train_e2fgvi.json
+ # For training E2FGVI-HQ
+ python train.py -c configs/train_e2fgvi_hq.json
+```
+You could run the same command if you want to resume your training.
+
+The training loss can be monitored by running:
+```shell
+tensorboard --logdir release_model
+```
+
+You could follow [this pipeline](https://github.com/MCG-NKU/E2FGVI#evaluation) to evaluate your model.
+## Results
+
+### Quantitative results
+
+## Citation
+
+ If you find our repo useful for your research, please consider citing our paper:
+
+ ```bibtex
+ @inproceedings{liCvpr22vInpainting,
+ title={Towards An End-to-End Framework for Flow-Guided Video Inpainting},
+ author={Li, Zhen and Lu, Cheng-Ze and Qin, Jianhua and Guo, Chun-Le and Cheng, Ming-Ming},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year={2022}
+ }
+ ```
+## Contact
+
+If you have any question, please feel free to contact us via `zhenli1031ATgmail.com` or `czlu919AToutlook.com`.
+
+## License
+Licensed under a [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) for Non-commercial use only.
+Any commercial use should get formal permission first.
+
+## Acknowledgement
+
+This repository is maintained by [Zhen Li](https://paper99.github.io) and [Cheng-Ze Lu](https://github.com/LGYoung).
+
+This code is based on [STTN](https://github.com/researchmm/STTN), [FuseFormer](https://github.com/ruiliu-ai/FuseFormer), [Focal-Transformer](https://github.com/microsoft/Focal-Transformer), and [MMEditing](https://github.com/open-mmlab/mmediting).
diff --git a/phantom/submodules/phantom-E2FGVI/README_zh-CN.md b/phantom/submodules/phantom-E2FGVI/README_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..a00726ff6e27dc34f555c7a236c5738686aedba5
--- /dev/null
+++ b/phantom/submodules/phantom-E2FGVI/README_zh-CN.md
@@ -0,0 +1,294 @@
+# E2FGVI (CVPR 2022)-简体中文
+[](https://paperswithcode.com/sota/video-inpainting-on-davis?p=towards-an-end-to-end-framework-for-flow)
+[](https://paperswithcode.com/sota/video-inpainting-on-youtube-vos?p=towards-an-end-to-end-framework-for-flow)
+
+
+
+
+[English](README.md) | 简体中文
+
+本项目包含了以下论文的官方实现:
+> **Towards An End-to-End Framework for Flow-Guided Video Inpainting**
+> Zhen Li#, Cheng-Ze Lu#, Jianhua Qin, Chun-Le Guo*, Ming-Ming Cheng
+> IEEE/CVF Conference on Computer Vision and Pattern Recognition (**CVPR**), 2022
+
+[[论文](https://arxiv.org/abs/2204.02663)]
+[[Demo Video (Youtube)](https://www.youtube.com/watch?v=N--qC3T2wc4)]
+[[演示视频 (B站)](https://www.bilibili.com/video/BV1Ta411n7eH?spm_id_from=333.999.0.0)]
+[项目主页 (待定)]
+[海报 (待定)]
+
+Colab实例:[](https://colab.research.google.com/drive/12rwY2gtG8jVWlNx9pjmmM8uGmh5ue18G?usp=sharing)
+
+## :star: 最新进展
+- *2022.05.15:* 可适配**任意分辨率**的E2FGVI-HQ已发布.该模型仅需要在 432x240 的分辨率下进行训练, 即可适配更高分辨率下的推理任务.并且, 该模型比原先模型能够取得**更好**的PSNR/SSIM指标.
+:link: 下载链接: [[Google Drive](https://drive.google.com/file/d/10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3/view?usp=sharing)] [[Baidu Disk](https://pan.baidu.com/s/1jfm1oFU1eIy-IRfuHP8YXw?pwd=ssb3)] :movie_camera: 演示视频: [[Youtube](https://www.youtube.com/watch?v=N--qC3T2wc4)] [[B站](https://www.bilibili.com/video/BV1Ta411n7eH?spm_id_from=333.999.0.0)]
+
+- *2022.04.06:* 代码公开发布.
+## 演示视频
+
+
+
+### 更多示例 (点击查看详情):
+
+
+
+## Requirements
+
+* Python >= 3.7.0
+* MMPose >= 0.23.0
+* MMDetection >= 2.21.0
+
+## Tutorials
+
+* [Get started with MMPose Webcam API (Chinese)](/tools/webcam/docs/get_started_cn.md)
+* [Build a Webcam App: A Step-by-step Instruction (Chinese)](/tools/webcam/docs/example_cn.md)
+
+## Examples
+
+* [Pose Estimation](/tools/webcam/configs/examples/): A simple example to estimate and visualize human/animal pose.
+* [Eye Effects](/tools/webcam/configs/eyes/): Apply sunglasses and bug-eye effects.
+* [Face Swap](/tools/webcam/configs/face_swap/): Everybody gets someone else's face.
+* [Meow Dwen Dwen](/tools/webcam/configs/meow_dwen_dwen/): Dress up your cat in Bing Dwen Dwen costume.
+* [Super Saiyan](/tools/webcam/configs/supersaiyan/): Super Saiyan transformation!
+* [New Year](/tools/webcam/configs/newyear/): Set off some firecrackers to celebrate Chinese New Year.
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/background/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/background/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7be8782e38717c6d537648e313921fb8c48b124e
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/background/README.md
@@ -0,0 +1,73 @@
+# Matting Effects
+
+We can apply background matting to the videos.
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/background/background.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| b | Toggle the background matting effect on/off. |
+| h | Show help information. |
+| m | Show the monitoring information. |
+| q | Exit. |
+
+Note that the demo will automatically save the output video into a file `record.mp4`.
+
+### Configuration
+
+- **Choose a detection model**
+
+Users can choose detection models from the [MMDetection Model Zoo](https://mmdetection.readthedocs.io/en/v2.20.0/model_zoo.html). Just set the `model_config` and `model_checkpoint` in the detector node accordingly, and the model will be automatically downloaded and loaded.
+Note that in order to perform background matting, the model should be able to produce segmentation masks.
+
+```python
+# 'DetectorNode':
+# This node performs object detection from the frame image using an
+# MMDetection model.
+dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/mask_rcnn_r50_fpn_2x_coco.py',
+ model_checkpoint='https://download.openmmlab.com/'
+ 'mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/'
+ 'mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392'
+ '__segm_mAP-0.354_20200505_003907-3e542a40.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+```
+
+- **Run the demo without GPU**
+
+If you don't have GPU and CUDA in your device, the demo can run with only CPU by setting `device='cpu'` in all model nodes. For example:
+
+```python
+dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/mask_rcnn_r50_fpn_2x_coco.py',
+ model_checkpoint='https://download.openmmlab.com/'
+ 'mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/'
+ 'mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392'
+ '__segm_mAP-0.354_20200505_003907-3e542a40.pth',
+ device='cpu',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+```
+
+- **Debug webcam and display**
+
+You can launch the webcam runner with a debug config:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/examples/test_camera.py
+```
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/background/background.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/background/background.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb9f4d616e929cbe7f3c789a729ce2c07d40b9a1
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/background/background.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Matting Effects',
+ camera_id=0,
+ camera_fps=10,
+ synchronous=False,
+ # Define nodes.
+ # The configuration of a node usually includes:
+ # 1. 'type': Node class name
+ # 2. 'name': Node name
+ # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
+ # input and output buffer names. This may depend on the node class.
+ # 4. 'enable_key': assign a hot-key to toggle enable/disable this node.
+ # This may depend on the node class.
+ # 5. Other class-specific arguments
+ nodes=[
+ # 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/mask_rcnn_r50_fpn_2x_coco.py',
+ model_checkpoint='https://download.openmmlab.com/'
+ 'mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/'
+ 'mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392'
+ '__segm_mAP-0.354_20200505_003907-3e542a40.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ # 'TopDownPoseEstimatorNode':
+ # This node performs keypoint detection from the frame image using an
+ # MMPose top-down model. Detection results is needed.
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangz'
+ 'hou.aliyuncs.com/mmpose/top_down/vipnas/vipnas_mbv3_co'
+ 'co_wholebody_256x192_dark-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose'),
+ # 'ModelResultBindingNode':
+ # This node binds the latest model inference result with the current
+ # frame. (This means the frame image and inference result may be
+ # asynchronous).
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='human_pose',
+ output_buffer='frame'),
+ # 'MattingNode':
+ # This node draw the matting visualization result in the frame image.
+ # mask results is needed.
+ dict(
+ type='BackgroundNode',
+ name='Visualizer',
+ enable_key='b',
+ enable=True,
+ frame_buffer='frame',
+ output_buffer='vis_bg',
+ cls_names=['person']),
+ # 'NoticeBoardNode':
+ # This node show a notice board with given content, e.g. help
+ # information.
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ frame_buffer='vis_bg',
+ output_buffer='vis',
+ content_lines=[
+ 'This is a demo for background changing effects. Have fun!',
+ '', 'Hot-keys:', '"b": Change background',
+ '"h": Show help information',
+ '"m": Show diagnostic information', '"q": Exit'
+ ],
+ ),
+ # 'MonitorNode':
+ # This node show diagnostic information in the frame image. It can
+ # be used for debugging or monitoring system resource status.
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis',
+ output_buffer='_display_') # `_frame_` is a runner-reserved buffer
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ec9b961d284631478b3c326872d75942437a7f0e
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/README.md
@@ -0,0 +1,110 @@
+# Pose Estimation Demo
+
+This demo performs human bounding box and keypoint detection, and visualizes results.
+
+
+
+
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/examples/pose_estimation.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| v | Toggle the pose visualization on/off. |
+| h | Show help information. |
+| m | Show the monitoring information. |
+| q | Exit. |
+
+Note that the demo will automatically save the output video into a file `record.mp4`.
+
+### Configuration
+
+- **Choose a detection model**
+
+Users can choose detection models from the [MMDetection Model Zoo](https://mmdetection.readthedocs.io/en/v2.20.0/model_zoo.html). Just set the `model_config` and `model_checkpoint` in the detector node accordingly, and the model will be automatically downloaded and loaded.
+
+```python
+# 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ input_buffer='_input_',
+ output_buffer='det_result')
+```
+
+- **Choose a or more pose models**
+
+In this demo we use two [top-down](https://github.com/open-mmlab/mmpose/tree/master/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap) pose estimation models for humans and animals respectively. Users can choose models from the [MMPose Model Zoo](https://mmpose.readthedocs.io/en/latest/modelzoo.html). To apply different pose models on different instance types, you can add multiple pose estimator nodes with `cls_names` set accordingly.
+
+```python
+# 'TopDownPoseEstimatorNode':
+# This node performs keypoint detection from the frame image using an
+# MMPose top-down model. Detection results is needed.
+dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangz'
+ 'hou.aliyuncs.com/mmpose/top_down/vipnas/vipnas_mbv3_co'
+ 'co_wholebody_256x192_dark-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose'),
+dict(
+ type='TopDownPoseEstimatorNode',
+ name='Animal Pose Estimator',
+ model_config='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap'
+ '/animalpose/hrnet_w32_animalpose_256x256.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/animal/'
+ 'hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+ cls_names=['cat', 'dog', 'horse', 'sheep', 'cow'],
+ input_buffer='human_pose',
+ output_buffer='animal_pose')
+```
+
+- **Run the demo without GPU**
+
+If you don't have GPU and CUDA in your device, the demo can run with only CPU by setting `device='cpu'` in all model nodes. For example:
+
+```python
+dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ device='cpu',
+ input_buffer='_input_',
+ output_buffer='det_result')
+```
+
+- **Debug webcam and display**
+
+You can lanch the webcam runner with a debug config:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/examples/test_camera.py
+```
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/pose_estimation.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/pose_estimation.py
new file mode 100644
index 0000000000000000000000000000000000000000..471333a448530c5b99f9016729b269953099f466
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/pose_estimation.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Pose Estimation',
+ camera_id=0,
+ camera_fps=20,
+ synchronous=False,
+ # Define nodes.
+ # The configuration of a node usually includes:
+ # 1. 'type': Node class name
+ # 2. 'name': Node name
+ # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
+ # input and output buffer names. This may depend on the node class.
+ # 4. 'enable_key': assign a hot-key to toggle enable/disable this node.
+ # This may depend on the node class.
+ # 5. Other class-specific arguments
+ nodes=[
+ # 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ # 'TopDownPoseEstimatorNode':
+ # This node performs keypoint detection from the frame image using an
+ # MMPose top-down model. Detection results is needed.
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/top_down/'
+ 'vipnas/vipnas_mbv3_coco_wholebody_256x192_dark'
+ '-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose'),
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Animal Pose Estimator',
+ model_config='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap'
+ '/animalpose/hrnet_w32_animalpose_256x256.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/animal/'
+ 'hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+ cls_names=['cat', 'dog', 'horse', 'sheep', 'cow'],
+ input_buffer='human_pose',
+ output_buffer='animal_pose'),
+ # 'ModelResultBindingNode':
+ # This node binds the latest model inference result with the current
+ # frame. (This means the frame image and inference result may be
+ # asynchronous).
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='animal_pose',
+ output_buffer='frame'),
+ # 'PoseVisualizerNode':
+ # This node draw the pose visualization result in the frame image.
+ # Pose results is needed.
+ dict(
+ type='PoseVisualizerNode',
+ name='Visualizer',
+ enable_key='v',
+ frame_buffer='frame',
+ output_buffer='vis'),
+ # 'NoticeBoardNode':
+ # This node show a notice board with given content, e.g. help
+ # information.
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ enable=True,
+ frame_buffer='vis',
+ output_buffer='vis_notice',
+ content_lines=[
+ 'This is a demo for pose visualization and simple image '
+ 'effects. Have fun!', '', 'Hot-keys:',
+ '"v": Pose estimation result visualization',
+ '"s": Sunglasses effect B-)', '"b": Bug-eye effect 0_0',
+ '"h": Show help information',
+ '"m": Show diagnostic information', '"q": Exit'
+ ],
+ ),
+ # 'MonitorNode':
+ # This node show diagnostic information in the frame image. It can
+ # be used for debugging or monitoring system resource status.
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis_notice',
+ output_buffer='display'),
+ # 'RecorderNode':
+ # This node save the output video into a file.
+ dict(
+ type='RecorderNode',
+ name='Recorder',
+ out_video_file='record.mp4',
+ frame_buffer='display',
+ output_buffer='_display_'
+ # `_display_` is a runner-reserved buffer
+ )
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/test_camera.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/test_camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c1677f4f1cbe8fe3dad081c7b9889602a39956
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/examples/test_camera.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ name='Debug CamRunner',
+ camera_id=0,
+ camera_fps=20,
+ nodes=[
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ frame_buffer='_frame_',
+ output_buffer='display'),
+ dict(
+ type='RecorderNode',
+ name='Recorder',
+ out_video_file='webcam_output.mp4',
+ frame_buffer='display',
+ output_buffer='_display_')
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/eyes/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/eyes/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9c37695eecb18a0e4becdbcc1aa59bde4e75247
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/eyes/README.md
@@ -0,0 +1,31 @@
+# Sunglasses and Bug-eye Effects
+
+We can apply fun effects on videos with pose estimation results, like adding sunglasses on the face, or make the eyes look bigger.
+
+
+
+
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/examples/pose_estimation.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| s | Toggle the sunglasses effect on/off. |
+| b | Toggle the bug-eye effect on/off. |
+| h | Show help information. |
+| m | Show the monitoring information. |
+| q | Exit. |
+
+### Configuration
+
+See the [README](/tools/webcam/configs/examples/README.md#configuration) of pose estimation demo for model configurations.
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/eyes/eyes.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/eyes/eyes.py
new file mode 100644
index 0000000000000000000000000000000000000000..91bbfba9d9f89f7c7071375bedcc73a1e18d1783
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/eyes/eyes.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Eye Effects',
+ camera_id=0,
+ camera_fps=20,
+ synchronous=False,
+ # Define nodes.
+ # The configuration of a node usually includes:
+ # 1. 'type': Node class name
+ # 2. 'name': Node name
+ # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
+ # input and output buffer names. This may depend on the node class.
+ # 4. 'enable_key': assign a hot-key to toggle enable/disable this node.
+ # This may depend on the node class.
+ # 5. Other class-specific arguments
+ nodes=[
+ # 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ # 'TopDownPoseEstimatorNode':
+ # This node performs keypoint detection from the frame image using an
+ # MMPose top-down model. Detection results is needed.
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangz'
+ 'hou.aliyuncs.com/mmpose/top_down/vipnas/vipnas_mbv3_co'
+ 'co_wholebody_256x192_dark-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose'),
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Animal Pose Estimator',
+ model_config='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap'
+ '/animalpose/hrnet_w32_animalpose_256x256.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/animal/'
+ 'hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+ cls_names=['cat', 'dog', 'horse', 'sheep', 'cow'],
+ input_buffer='human_pose',
+ output_buffer='animal_pose'),
+ # 'ModelResultBindingNode':
+ # This node binds the latest model inference result with the current
+ # frame. (This means the frame image and inference result may be
+ # asynchronous).
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='animal_pose',
+ output_buffer='frame'),
+ # 'SunglassesNode':
+ # This node draw the sunglasses effect in the frame image.
+ # Pose results is needed.
+ dict(
+ type='SunglassesNode',
+ name='Visualizer',
+ enable_key='s',
+ enable=True,
+ frame_buffer='frame',
+ output_buffer='vis_sunglasses'),
+ # 'BugEyeNode':
+ # This node draw the bug-eye effetc in the frame image.
+ # Pose results is needed.
+ dict(
+ type='BugEyeNode',
+ name='Visualizer',
+ enable_key='b',
+ enable=False,
+ frame_buffer='vis_sunglasses',
+ output_buffer='vis_bugeye'),
+ # 'NoticeBoardNode':
+ # This node show a notice board with given content, e.g. help
+ # information.
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ frame_buffer='vis_bugeye',
+ output_buffer='vis',
+ content_lines=[
+ 'This is a demo for pose visualization and simple image '
+ 'effects. Have fun!', '', 'Hot-keys:',
+ '"s": Sunglasses effect B-)', '"b": Bug-eye effect 0_0',
+ '"h": Show help information',
+ '"m": Show diagnostic information', '"q": Exit'
+ ],
+ ),
+ # 'MonitorNode':
+ # This node show diagnostic information in the frame image. It can
+ # be used for debugging or monitoring system resource status.
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis',
+ output_buffer='_display_') # `_frame_` is a runner-reserved buffer
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/face_swap/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/face_swap/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..02f4c8aa855702bf6a668970f8e7e071611caf8e
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/face_swap/README.md
@@ -0,0 +1,31 @@
+# Sunglasses and Bug-eye Effects
+
+Look! Where is my face?:eyes: And whose face is it?:laughing:
+
+
+
+
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/face_swap/face_swap.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| s | Switch between modes
Shuffle: Randomly shuffle all faces
Clone: Choose one face and clone it for everyone
None: Nothing happens and everyone is safe :)
|
+| v | Toggle the pose visualization on/off. |
+| h | Show help information. |
+| m | Show diagnostic information. |
+| q | Exit. |
+
+### Configuration
+
+See the [README](/tools/webcam/configs/examples/README.md#configuration) of pose estimation demo for model configurations.
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/face_swap/face_swap.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/face_swap/face_swap.py
new file mode 100644
index 0000000000000000000000000000000000000000..403eaae4ace483d72a4baedbaf61072c24e3a1ec
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/face_swap/face_swap.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ name='FaceSwap',
+ camera_id=0,
+ camera_fps=20,
+ synchronous=False,
+ nodes=[
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ device='cpu',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='TopDown Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_res50_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangzhou'
+ '.aliyuncs.com/mmpose/top_down/vipnas/'
+ 'vipnas_res50_wholebody_256x192_dark-67c0ce35_20211112.pth',
+ device='cpu',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='pose_result'),
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='pose_result',
+ output_buffer='frame'),
+ dict(
+ type='FaceSwapNode',
+ name='FaceSwapper',
+ mode_key='s',
+ frame_buffer='frame',
+ output_buffer='face_swap'),
+ dict(
+ type='PoseVisualizerNode',
+ name='Visualizer',
+ enable_key='v',
+ frame_buffer='face_swap',
+ output_buffer='vis_pose'),
+ dict(
+ type='NoticeBoardNode',
+ name='Help Information',
+ enable_key='h',
+ content_lines=[
+ 'Swap your faces! ',
+ 'Hot-keys:',
+ '"v": Toggle the pose visualization on/off.',
+ '"s": Switch between modes: Shuffle, Clone and None',
+ '"h": Show help information',
+ '"m": Show diagnostic information',
+ '"q": Exit',
+ ],
+ frame_buffer='vis_pose',
+ output_buffer='vis_notice'),
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis_notice',
+ output_buffer='display'),
+ dict(
+ type='RecorderNode',
+ name='Recorder',
+ out_video_file='faceswap_output.mp4',
+ frame_buffer='display',
+ output_buffer='_display_')
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..997ffc174bd70c2de6a22edee53f5b52275ae187
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/README.md
@@ -0,0 +1,44 @@
+# Meow Dwen Dwen
+
+Do you know [Bing DwenDwen (冰墩墩)](https://en.wikipedia.org/wiki/Bing_Dwen_Dwen_and_Shuey_Rhon_Rhon), the mascot of 2022 Beijing Olympic Games?
+
+
+
+
+
+Now you can dress your cat up in this costume and TA-DA! Be prepared for super cute **Meow Dwen Dwen**.
+
+
+
+
+
+You are a dog fan? Hold on, here comes Woof Dwen Dwen.
+
+
+
+
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/meow_dwen_dwen/meow_dwen_dwen.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| s | Change the background. |
+| h | Show help information. |
+| m | Show diagnostic information. |
+| q | Exit. |
+
+### Configuration
+
+- **Use video input**
+
+As you can see in the config, we set `camera_id` as the path of the input image. You can also set it as a video file path (or url), or a webcam ID number (e.g. `camera_id=0`), to capture the dynamic face from the video input.
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/meow_dwen_dwen.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/meow_dwen_dwen.py
new file mode 100644
index 0000000000000000000000000000000000000000..399d01cf7c8df103772913294f1c0612979330e6
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/meow_dwen_dwen.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Little fans of 2022 Beijing Winter Olympics',
+ # Cat image
+ camera_id='https://user-images.githubusercontent.com/'
+ '15977946/152932036-b5554cf8-24cf-40d6-a358-35a106013f11.jpeg',
+ # Dog image
+ # camera_id='https://user-images.githubusercontent.com/'
+ # '15977946/152932051-cd280b35-8066-45a0-8f52-657c8631aaba.jpg',
+ camera_fps=20,
+ nodes=[
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Animal Pose Estimator',
+ model_config='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap'
+ '/ap10k/hrnet_w32_ap10k_256x256.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/animal/'
+ 'hrnet/hrnet_w32_ap10k_256x256-18aac840_20211029.pth',
+ cls_names=['cat', 'dog'],
+ input_buffer='det_result',
+ output_buffer='animal_pose'),
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='TopDown Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_res50_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangzhou'
+ '.aliyuncs.com/mmpose/top_down/vipnas/'
+ 'vipnas_res50_wholebody_256x192_dark-67c0ce35_20211112.pth',
+ device='cpu',
+ cls_names=['person'],
+ input_buffer='animal_pose',
+ output_buffer='human_pose'),
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='human_pose',
+ output_buffer='frame'),
+ dict(
+ type='XDwenDwenNode',
+ name='XDwenDwen',
+ mode_key='s',
+ resource_file='tools/webcam/configs/meow_dwen_dwen/'
+ 'resource-info.json',
+ out_shape=(480, 480),
+ frame_buffer='frame',
+ output_buffer='vis'),
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ enable=False,
+ frame_buffer='vis',
+ output_buffer='vis_notice',
+ content_lines=[
+ 'Let your pet put on a costume of Bing-Dwen-Dwen, '
+ 'the mascot of 2022 Beijing Winter Olympics. Have fun!', '',
+ 'Hot-keys:', '"s": Change the background',
+ '"h": Show help information',
+ '"m": Show diagnostic information', '"q": Exit'
+ ],
+ ),
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis_notice',
+ output_buffer='display'),
+ dict(
+ type='RecorderNode',
+ name='Recorder',
+ out_video_file='record.mp4',
+ frame_buffer='display',
+ output_buffer='_display_'
+ # `_display_` is a runner-reserved buffer
+ )
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/resource-info.json b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/resource-info.json
new file mode 100644
index 0000000000000000000000000000000000000000..adb811cc7f3eafea56ff4d3f577ec28e33e80f0a
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/meow_dwen_dwen/resource-info.json
@@ -0,0 +1,26 @@
+[
+ {
+ "id": 1,
+ "result": "{\"width\":690,\"height\":713,\"valid\":true,\"rotate\":0,\"step_1\":{\"toolName\":\"pointTool\",\"result\":[{\"x\":374.86387434554973,\"y\":262.8020942408377,\"attribute\":\"\",\"valid\":true,\"id\":\"8SK9cVyu\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":2},{\"x\":492.8261780104712,\"y\":285.2,\"attribute\":\"\",\"valid\":true,\"id\":\"qDk54WsI\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":1},{\"x\":430.11204188481673,\"y\":318.0502617801047,\"attribute\":\"\",\"valid\":true,\"id\":\"4H80L7lL\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":3}]},\"step_2\":{\"dataSourceStep\":0,\"toolName\":\"polygonTool\",\"result\":[{\"id\":\"pwUsrf9u\",\"sourceID\":\"\",\"valid\":true,\"textAttribute\":\"\",\"pointList\":[{\"x\":423.3926701570681,\"y\":191.87539267015708},{\"x\":488.3465968586388,\"y\":209.04712041884818},{\"x\":535.3821989528797,\"y\":248.6167539267016},{\"x\":549.5675392670157,\"y\":306.8513089005236},{\"x\":537.6219895287959,\"y\":349.407329842932},{\"x\":510.74450261780106,\"y\":381.51099476439794},{\"x\":480.1340314136126,\"y\":394.9497382198953},{\"x\":411.4471204188482,\"y\":390.47015706806286},{\"x\":355.45235602094243,\"y\":373.29842931937173},{\"x\":306.17696335078534,\"y\":327.00942408376966},{\"x\":294.97801047120424,\"y\":284.45340314136126},{\"x\":306.9235602094241,\"y\":245.6303664921466},{\"x\":333.8010471204189,\"y\":217.25968586387435},{\"x\":370.3842931937173,\"y\":196.35497382198955}],\"attribute\":\"\",\"order\":1}]}}",
+ "url": "https://user-images.githubusercontent.com/15977946/152742677-35fe8a01-bd06-4a12-a02e-949e7d71f28a.jpg",
+ "fileName": "bing_dwen_dwen1.jpg"
+ },
+ {
+ "id": 2,
+ "result": "{\"width\":690,\"height\":659,\"valid\":true,\"rotate\":0,\"step_1\":{\"dataSourceStep\":0,\"toolName\":\"pointTool\",\"result\":[{\"x\":293.2460732984293,\"y\":242.89842931937173,\"attribute\":\"\",\"valid\":true,\"id\":\"KgPs39bY\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":1},{\"x\":170.41675392670155,\"y\":270.50052356020944,\"attribute\":\"\",\"valid\":true,\"id\":\"XwHyoBFU\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":2},{\"x\":224.24083769633506,\"y\":308.45340314136126,\"attribute\":\"\",\"valid\":true,\"id\":\"Qfs4YfuB\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":3}]},\"step_2\":{\"dataSourceStep\":0,\"toolName\":\"polygonTool\",\"result\":[{\"id\":\"ts5jlJxb\",\"sourceID\":\"\",\"valid\":true,\"textAttribute\":\"\",\"pointList\":[{\"x\":178.69738219895285,\"y\":184.93403141361256},{\"x\":204.91937172774865,\"y\":172.5130890052356},{\"x\":252.5329842931937,\"y\":169.0628272251309},{\"x\":295.3162303664921,\"y\":175.27329842931937},{\"x\":333.95916230366487,\"y\":195.2848167539267},{\"x\":360.18115183246067,\"y\":220.1267015706806},{\"x\":376.0523560209424,\"y\":262.909947643979},{\"x\":373.98219895287957,\"y\":296.0324607329843},{\"x\":344.99999999999994,\"y\":335.365445026178},{\"x\":322.22827225130885,\"y\":355.37696335078533},{\"x\":272.544502617801,\"y\":378.1486910994764},{\"x\":221.48062827225127,\"y\":386.42931937172773},{\"x\":187.6680628272251,\"y\":385.7392670157068},{\"x\":158.68586387434553,\"y\":369.1780104712042},{\"x\":137.98429319371724,\"y\":337.43560209424083},{\"x\":127.63350785340312,\"y\":295.34240837696336},{\"x\":131.0837696335078,\"y\":242.89842931937173},{\"x\":147.64502617801045,\"y\":208.3958115183246}],\"attribute\":\"\",\"order\":1}]}}",
+ "url": "https://user-images.githubusercontent.com/15977946/152742707-c0c51844-e1d0-42d0-9a12-e369002e082f.jpg",
+ "fileName": "bing_dwen_dwen2.jpg"
+ },
+ {
+ "id": 3,
+ "result": "{\"width\":690,\"height\":811,\"valid\":true,\"rotate\":0,\"step_1\":{\"dataSourceStep\":0,\"toolName\":\"pointTool\",\"result\":[{\"x\":361.13507853403144,\"y\":300.62198952879584,\"attribute\":\"\",\"valid\":true,\"id\":\"uAtbXtf2\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":1},{\"x\":242.24502617801048,\"y\":317.60628272251313,\"attribute\":\"\",\"valid\":true,\"id\":\"iLtceHMA\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":2},{\"x\":302.5392670157068,\"y\":356.67015706806285,\"attribute\":\"\",\"valid\":true,\"id\":\"n9MTlJ6A\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":3}]},\"step_2\":{\"dataSourceStep\":0,\"toolName\":\"polygonTool\",\"result\":[{\"id\":\"5sTLU5wF\",\"sourceID\":\"\",\"valid\":true,\"textAttribute\":\"\",\"pointList\":[{\"x\":227.80837696335078,\"y\":247.12146596858642},{\"x\":248.18952879581153,\"y\":235.23246073298432},{\"x\":291.4994764397906,\"y\":225.04188481675394},{\"x\":351.7937172774869,\"y\":229.28795811518327},{\"x\":393.40523560209425,\"y\":245.42303664921468},{\"x\":424.8261780104712,\"y\":272.59790575916236},{\"x\":443.5089005235602,\"y\":298.07434554973827},{\"x\":436.7151832460733,\"y\":345.6303664921466},{\"x\":406.1434554973822,\"y\":382.9958115183247},{\"x\":355.1905759162304,\"y\":408.4722513089006},{\"x\":313.57905759162304,\"y\":419.5120418848168},{\"x\":262.6261780104712,\"y\":417.81361256544506},{\"x\":224.41151832460733,\"y\":399.9801047120419},{\"x\":201.48272251308902,\"y\":364.3130890052356},{\"x\":194.68900523560208,\"y\":315.0586387434555},{\"x\":202.33193717277487,\"y\":272.59790575916236}],\"attribute\":\"\",\"order\":1}]}}",
+ "url": "https://user-images.githubusercontent.com/15977946/152742728-99392ecf-8f5c-46cf-b5c4-fe7fb6b39976.jpg",
+ "fileName": "bing_dwen_dwen3.jpg"
+ },
+ {
+ "id": 4,
+ "result": "{\"width\":690,\"height\":690,\"valid\":true,\"rotate\":0,\"step_1\":{\"dataSourceStep\":0,\"toolName\":\"pointTool\",\"result\":[{\"x\":365.9528795811519,\"y\":464.5759162303665,\"attribute\":\"\",\"valid\":true,\"id\":\"IKprTuHS\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":1},{\"x\":470.71727748691103,\"y\":445.06806282722516,\"attribute\":\"\",\"valid\":true,\"id\":\"Z90CWkEI\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":2},{\"x\":410.74869109947645,\"y\":395.2146596858639,\"attribute\":\"\",\"valid\":true,\"id\":\"UWRstKZk\",\"sourceID\":\"\",\"textAttribute\":\"\",\"order\":3}]},\"step_2\":{\"dataSourceStep\":0,\"toolName\":\"polygonTool\",\"result\":[{\"id\":\"C30Pc9Ww\",\"sourceID\":\"\",\"valid\":true,\"textAttribute\":\"\",\"pointList\":[{\"x\":412.91623036649213,\"y\":325.85340314136124},{\"x\":468.5497382198953,\"y\":335.9685863874345},{\"x\":501.78534031413614,\"y\":369.2041884816754},{\"x\":514.0680628272252,\"y\":415.44502617801044},{\"x\":504.67539267015707,\"y\":472.5235602094241},{\"x\":484.44502617801044,\"y\":497.0890052356021},{\"x\":443.26178010471205,\"y\":512.9842931937172},{\"x\":389.7958115183246,\"y\":518.7643979057591},{\"x\":336.32984293193715,\"y\":504.31413612565444},{\"x\":302.3717277486911,\"y\":462.40837696335075},{\"x\":298.0366492146597,\"y\":416.89005235602093},{\"x\":318.26701570680626,\"y\":372.0942408376963},{\"x\":363.0628272251309,\"y\":341.0261780104712}],\"attribute\":\"\",\"order\":1}]}}",
+ "url": "https://user-images.githubusercontent.com/15977946/152742755-9dc75f89-4156-4103-9c6d-f35f1f409d11.jpg",
+ "fileName": "bing_dwen_dwen4.jpg"
+ }
+]
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/newyear/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/newyear/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c655c121e236146a00a378b5bf495dbf24e6888
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/newyear/README.md
@@ -0,0 +1,31 @@
+# New Year Hat and Firecracker Effects
+
+This demo provides new year effects with pose estimation results, like adding hat on the head and firecracker in the hands.
+
+
+
+
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/newyear/new_year.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| t | Toggle the hat effect on/off. |
+| f | Toggle the firecracker effect on/off. |
+| h | Show help information. |
+| m | Show the monitoring information. |
+| q | Exit. |
+
+### Configuration
+
+See the [README](/tools/webcam/configs/examples/README.md#configuration) of pose estimation demo for model configurations.
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/newyear/new_year.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/newyear/new_year.py
new file mode 100644
index 0000000000000000000000000000000000000000..3551184053312da288ccac95ae9f37e7f116dd1b
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/newyear/new_year.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Pose Estimation',
+ camera_id=0,
+ camera_fps=20,
+ synchronous=False,
+ # Define nodes.
+ # The configuration of a node usually includes:
+ # 1. 'type': Node class name
+ # 2. 'name': Node name
+ # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
+ # input and output buffer names. This may depend on the node class.
+ # 4. 'enable_key': assign a hot-key to toggle enable/disable this node.
+ # This may depend on the node class.
+ # 5. Other class-specific arguments
+ nodes=[
+ # 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ # 'TopDownPoseEstimatorNode':
+ # This node performs keypoint detection from the frame image using an
+ # MMPose top-down model. Detection results is needed.
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangz'
+ 'hou.aliyuncs.com/mmpose/top_down/vipnas/vipnas_mbv3_co'
+ 'co_wholebody_256x192_dark-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose'),
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Animal Pose Estimator',
+ model_config='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap'
+ '/animalpose/hrnet_w32_animalpose_256x256.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/animal/'
+ 'hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+ cls_names=['cat', 'dog', 'horse', 'sheep', 'cow'],
+ input_buffer='human_pose',
+ output_buffer='animal_pose'),
+ # 'ModelResultBindingNode':
+ # This node binds the latest model inference result with the current
+ # frame. (This means the frame image and inference result may be
+ # asynchronous).
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='animal_pose',
+ output_buffer='frame'),
+ # 'HatNode':
+ # This node draw the hat effect in the frame image.
+ # Pose results is needed.
+ dict(
+ type='HatNode',
+ name='Visualizer',
+ enable_key='t',
+ frame_buffer='frame',
+ output_buffer='vis_hat'),
+ # 'FirecrackerNode':
+ # This node draw the firecracker effect in the frame image.
+ # Pose results is needed.
+ dict(
+ type='FirecrackerNode',
+ name='Visualizer',
+ enable_key='f',
+ frame_buffer='vis_hat',
+ output_buffer='vis_firecracker'),
+ # 'NoticeBoardNode':
+ # This node show a notice board with given content, e.g. help
+ # information.
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ enable=True,
+ frame_buffer='vis_firecracker',
+ output_buffer='vis_notice',
+ content_lines=[
+ 'This is a demo for pose visualization and simple image '
+ 'effects. Have fun!', '', 'Hot-keys:', '"t": Hat effect',
+ '"f": Firecracker effect', '"h": Show help information',
+ '"m": Show diagnostic information', '"q": Exit'
+ ],
+ ),
+ # 'MonitorNode':
+ # This node show diagnostic information in the frame image. It can
+ # be used for debugging or monitoring system resource status.
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis_notice',
+ output_buffer='display'),
+ # 'RecorderNode':
+ # This node save the output video into a file.
+ dict(
+ type='RecorderNode',
+ name='Recorder',
+ out_video_file='record.mp4',
+ frame_buffer='display',
+ output_buffer='_display_'
+ # `_display_` is a runner-reserved buffer
+ )
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/supersaiyan/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/supersaiyan/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e9aef1bbaa7c62277a039cfad995a01e0491a10
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/supersaiyan/README.md
@@ -0,0 +1,96 @@
+# Super Saiyan Effects
+
+We can apply fun effects on videos with pose estimation results, like Super Saiyan transformation.
+
+https://user-images.githubusercontent.com/11788150/150138076-2192079f-068a-4d43-bf27-2f1fd708cabc.mp4
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/supersaiyan/saiyan.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| s | Toggle the Super Saiyan effect on/off. |
+| h | Show help information. |
+| m | Show the monitoring information. |
+| q | Exit. |
+
+Note that the demo will automatically save the output video into a file `record.mp4`.
+
+### Configuration
+
+- **Choose a detection model**
+
+Users can choose detection models from the [MMDetection Model Zoo](https://mmdetection.readthedocs.io/en/v2.20.0/model_zoo.html). Just set the `model_config` and `model_checkpoint` in the detector node accordingly, and the model will be automatically downloaded and loaded.
+
+```python
+# 'DetectorNode':
+# This node performs object detection from the frame image using an
+# MMDetection model.
+dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/mask_rcnn_r50_fpn_2x_coco.py',
+ model_checkpoint='https://download.openmmlab.com/'
+ 'mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/'
+ 'mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392'
+ '__segm_mAP-0.354_20200505_003907-3e542a40.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+```
+
+- **Choose a or more pose models**
+
+In this demo we use two [top-down](https://github.com/open-mmlab/mmpose/tree/master/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap) pose estimation models for humans and animals respectively. Users can choose models from the [MMPose Model Zoo](https://mmpose.readthedocs.io/en/latest/modelzoo.html). To apply different pose models on different instance types, you can add multiple pose estimator nodes with `cls_names` set accordingly.
+
+```python
+# 'TopDownPoseEstimatorNode':
+# This node performs keypoint detection from the frame image using an
+# MMPose top-down model. Detection results is needed.
+dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangz'
+ 'hou.aliyuncs.com/mmpose/top_down/vipnas/vipnas_mbv3_co'
+ 'co_wholebody_256x192_dark-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose')
+```
+
+- **Run the demo without GPU**
+
+If you don't have GPU and CUDA in your device, the demo can run with only CPU by setting `device='cpu'` in all model nodes. For example:
+
+```python
+dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/mask_rcnn_r50_fpn_2x_coco.py',
+ model_checkpoint='https://download.openmmlab.com/'
+ 'mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/'
+ 'mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392'
+ '__segm_mAP-0.354_20200505_003907-3e542a40.pth',
+ device='cpu',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+```
+
+- **Debug webcam and display**
+
+You can launch the webcam runner with a debug config:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/examples/test_camera.py
+```
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/supersaiyan/saiyan.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/supersaiyan/saiyan.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a8e7bc82c7ca53fb6a0350ce8b0bd3e3ac6e737
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/supersaiyan/saiyan.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Super Saiyan Effects',
+ camera_id=0,
+ camera_fps=30,
+ synchronous=False,
+ # Define nodes.
+ # The configuration of a node usually includes:
+ # 1. 'type': Node class name
+ # 2. 'name': Node name
+ # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
+ # input and output buffer names. This may depend on the node class.
+ # 4. 'enable_key': assign a hot-key to toggle enable/disable this node.
+ # This may depend on the node class.
+ # 5. Other class-specific arguments
+ nodes=[
+ # 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/mask_rcnn_r50_fpn_2x_coco.py',
+ model_checkpoint='https://download.openmmlab.com/'
+ 'mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_2x_coco/'
+ 'mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392'
+ '__segm_mAP-0.354_20200505_003907-3e542a40.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ # 'TopDownPoseEstimatorNode':
+ # This node performs keypoint detection from the frame image using an
+ # MMPose top-down model. Detection results is needed.
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://openmmlab-share.oss-cn-hangz'
+ 'hou.aliyuncs.com/mmpose/top_down/vipnas/vipnas_mbv3_co'
+ 'co_wholebody_256x192_dark-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='human_pose'),
+ # 'ModelResultBindingNode':
+ # This node binds the latest model inference result with the current
+ # frame. (This means the frame image and inference result may be
+ # asynchronous).
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='human_pose',
+ output_buffer='frame'),
+ # 'SaiyanNode':
+ # This node draw the Super Saiyan effect in the frame image.
+ # Pose results is needed.
+ dict(
+ type='SaiyanNode',
+ name='Visualizer',
+ enable_key='s',
+ cls_names=['person'],
+ enable=True,
+ frame_buffer='frame',
+ output_buffer='vis_saiyan'),
+ # 'NoticeBoardNode':
+ # This node show a notice board with given content, e.g. help
+ # information.
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ frame_buffer='vis_saiyan',
+ output_buffer='vis',
+ content_lines=[
+ 'This is a demo for super saiyan effects. Have fun!', '',
+ 'Hot-keys:', '"s": Saiyan effect',
+ '"h": Show help information',
+ '"m": Show diagnostic information', '"q": Exit'
+ ],
+ ),
+ # 'MonitorNode':
+ # This node show diagnostic information in the frame image. It can
+ # be used for debugging or monitoring system resource status.
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis',
+ output_buffer='_display_') # `_frame_` is a runner-reserved buffer
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/valentinemagic/README.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/valentinemagic/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8063d2e18640a4312167ed1c022fce3cf613937e
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/valentinemagic/README.md
@@ -0,0 +1,35 @@
+# Valentine Magic
+
+Do you want to show your **love** to your beloved one, especially on Valentine's Day? Express it with your pose using MMPose right away and see the Valentine Magic!
+
+Try to pose a hand heart gesture, and see what will happen?
+
+Prefer a blow kiss? Here comes your flying heart~
+
+
+
+
+
+## Instruction
+
+### Get started
+
+Launch the demo from the mmpose root directory:
+
+```shell
+python tools/webcam/run_webcam.py --config tools/webcam/configs/valentinemagic/valentinemagic.py
+```
+
+### Hotkeys
+
+| Hotkey | Function |
+| -- | -- |
+| l | Toggle the Valentine Magic effect on/off. |
+| v | Toggle the pose visualization on/off. |
+| h | Show help information. |
+| m | Show diagnostic information. |
+| q | Exit. |
+
+### Configuration
+
+See the [README](/tools/webcam/configs/examples/README.md#configuration) of pose estimation demo for model configurations.
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/valentinemagic/valentinemagic.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/valentinemagic/valentinemagic.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f921b07901805b490be264c28e12c7de3648f8b
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/configs/valentinemagic/valentinemagic.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+runner = dict(
+ # Basic configurations of the runner
+ name='Human Pose and Effects',
+ camera_id=0,
+ camera_fps=30,
+
+ # Define nodes.
+ #
+ # The configuration of a node usually includes:
+ # 1. 'type': Node class name
+ # 2. 'name': Node name
+ # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
+ # input and output buffer names. This may depend on the node class.
+ # 4. 'enable_key': assign a hot-key to toggle enable/disable this node.
+ # This may depend on the node class.
+ # 5. Other class-specific arguments
+ nodes=[
+ # 'DetectorNode':
+ # This node performs object detection from the frame image using an
+ # MMDetection model.
+ dict(
+ type='DetectorNode',
+ name='Detector',
+ model_config='demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ model_checkpoint='https://download.openmmlab.com'
+ '/mmdetection/v2.0/ssd/'
+ 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ input_buffer='_input_', # `_input_` is a runner-reserved buffer
+ output_buffer='det_result'),
+ # 'TopDownPoseEstimatorNode':
+ # This node performs keypoint detection from the frame image using an
+ # MMPose top-down model. Detection results is needed.
+ dict(
+ type='TopDownPoseEstimatorNode',
+ name='Human Pose Estimator',
+ model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ 'topdown_heatmap/coco-wholebody/'
+ 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ model_checkpoint='https://download.openmmlab.com/mmpose/top_down/'
+ 'vipnas/vipnas_mbv3_coco_wholebody_256x192_dark'
+ '-e2158108_20211205.pth',
+ cls_names=['person'],
+ input_buffer='det_result',
+ output_buffer='pose_result'),
+ # 'ModelResultBindingNode':
+ # This node binds the latest model inference result with the current
+ # frame. (This means the frame image and inference result may be
+ # asynchronous).
+ dict(
+ type='ModelResultBindingNode',
+ name='ResultBinder',
+ frame_buffer='_frame_', # `_frame_` is a runner-reserved buffer
+ result_buffer='pose_result',
+ output_buffer='frame'),
+ # 'PoseVisualizerNode':
+ # This node draw the pose visualization result in the frame image.
+ # Pose results is needed.
+ dict(
+ type='PoseVisualizerNode',
+ name='Visualizer',
+ enable_key='v',
+ enable=False,
+ frame_buffer='frame',
+ output_buffer='vis'),
+ # 'ValentineMagicNode':
+ # This node draw heart in the image.
+ # It can launch dynamically expanding heart from the middle of
+ # hands if the persons pose a "hand heart" gesture or blow a kiss.
+ # Only there are two persons in the image can trigger this effect.
+ # Pose results is needed.
+ dict(
+ type='ValentineMagicNode',
+ name='Visualizer',
+ enable_key='l',
+ frame_buffer='vis',
+ output_buffer='vis_heart',
+ ),
+ # 'NoticeBoardNode':
+ # This node show a notice board with given content, e.g. help
+ # information.
+ dict(
+ type='NoticeBoardNode',
+ name='Helper',
+ enable_key='h',
+ enable=False,
+ frame_buffer='vis_heart',
+ output_buffer='vis_notice',
+ content_lines=[
+ 'This is a demo for pose visualization and simple image '
+ 'effects. Have fun!', '', 'Hot-keys:',
+ '"h": Show help information', '"l": LoveHeart Effect',
+ '"v": PoseVisualizer', '"m": Show diagnostic information',
+ '"q": Exit'
+ ],
+ ),
+ # 'MonitorNode':
+ # This node show diagnostic information in the frame image. It can
+ # be used for debugging or monitoring system resource status.
+ dict(
+ type='MonitorNode',
+ name='Monitor',
+ enable_key='m',
+ enable=False,
+ frame_buffer='vis_notice',
+ output_buffer='display'), # `_frame_` is a runner-reserved buffer
+ # 'RecorderNode':
+ # This node record the frames into a local file. It can save the
+ # visualiztion results. Uncommit the following lines to turn it on.
+ dict(
+ type='RecorderNode',
+ name='Recorder',
+ out_video_file='record.mp4',
+ frame_buffer='display',
+ output_buffer='_display_')
+ ])
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/docs/example_cn.md b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/docs/example_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..69b9898c3237ab6c81b6af28dfcb50224ac424df
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/docs/example_cn.md
@@ -0,0 +1,171 @@
+# 开发示例:给猫咪戴上太阳镜
+
+## 设计思路
+
+在动手之前,我们先考虑如何实现这个功能:
+
+- 首先,要做目标检测,找到图像中的猫咪
+- 接着,要估计猫咪的关键点位置,比如左右眼的位置
+- 最后,把太阳镜素材图片贴在合适的位置,TA-DA!
+
+按照这个思路,下面我们来看如何一步一步实现它。
+
+## Step 1:从一个现成的 Config 开始
+
+在 WebcamAPI 中,已经添加了一些实现常用功能的 Node,并提供了对应的 config 示例。利用这些可以减少用户的开发量。例如,我们可以以上面的姿态估计 demo 为基础。它的 config 位于 `tools/webcam/configs/example/pose_estimation.py`。为了更直观,我们把这个 config 中的功能节点表示成以下流程图:
+
+
+
+图中的每个 Data Buffer 就是一个用来存放数据的容器。用户不需要关注 buffer 的具体细节,只需要将其简单理解成 Node 输入输出的名字即可。用户在 config 中可以任意定义这些名字,不过要注意有以下几个特殊的名字:
+
+- _input_:存放 runner 读入的视频帧,用于模型推理
+- _frame_ :存放 runner 读入的视频帧,用于可视化
+- _display_:存放经过所以 Node 处理后的结果,用于在屏幕上显示
+
+当一帧视频数据被 runner 读入后,会被放进 _input_ 和 _frame_ 两个 buffer 中,然后按照 config 中定义的 Node 连接关系依次通过各个 Node ,最终到达 _display_ ,并被 runner 读出显示在屏幕上。
+
+#### Get Advanced: 关于 buffer
+
+- Buffer 本质是一个有限长度的队列,在 runner 中会包含一个 BufferManager 实例(见`mmpose/tools/webcam/webcam_apis/buffer.py')来生成和管理所有 buffer。Node 会按照 config 从对应的 buffer 中读出或写入数据。
+- 当一个 buffer 已满(达到最大长度)时,写入数据的操作通常不会被 block,而是会将 buffer 中已有的最早一条数据“挤出去”。
+- 为什么有_input_和_frame_两个输入呢?因为有些 Node 的操作较为耗时(如目标检测,姿态估计等需要模型推理的 Node)。为了保证显示的流畅,我们通常用_input_来作为这类耗时较大的操作的输入,而用_frame_来实时绘制可视化的结果。因为各个节点是异步运行的,这样就可以保证可视化的实时和流畅。
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/run_webcam.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/run_webcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce8d92e78e385d5bfaf2782cfc5b9d627531d20b
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/run_webcam.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from argparse import ArgumentParser
+
+from mmcv import Config, DictAction
+from webcam_apis import WebcamRunner
+
+
+def parse_args():
+ parser = ArgumentParser('Lauch webcam runner')
+ parser.add_argument(
+ '--config',
+ type=str,
+ default='tools/webcam/configs/meow_dwen_dwen/meow_dwen_dwen.py')
+
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ default={},
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. For example, '
+ "'--cfg-options runner.camera_id=1 runner.synchronous=True'")
+
+ return parser.parse_args()
+
+
+def launch():
+ args = parse_args()
+ cfg = Config.fromfile(args.config)
+ cfg.merge_from_dict(args.cfg_options)
+
+ runner = WebcamRunner(**cfg.runner)
+ runner.run()
+
+
+if __name__ == '__main__':
+ launch()
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/__init__.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8a2f5e0f6bf8d3c1b3d766dbe7a7d2c69cfaa4
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .webcam_runner import WebcamRunner
+
+__all__ = ['WebcamRunner']
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/__init__.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a882030b4a1b5aac87206e84fe69041bcd83035f
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import NODES
+from .faceswap_node import FaceSwapNode
+from .frame_effect_node import (BackgroundNode, BugEyeNode, MoustacheNode,
+ NoticeBoardNode, PoseVisualizerNode,
+ SaiyanNode, SunglassesNode)
+from .helper_node import ModelResultBindingNode, MonitorNode, RecorderNode
+from .mmdet_node import DetectorNode
+from .mmpose_node import TopDownPoseEstimatorNode
+from .valentinemagic_node import ValentineMagicNode
+from .xdwendwen_node import XDwenDwenNode
+
+__all__ = [
+ 'NODES', 'PoseVisualizerNode', 'DetectorNode', 'TopDownPoseEstimatorNode',
+ 'MonitorNode', 'BugEyeNode', 'SunglassesNode', 'ModelResultBindingNode',
+ 'NoticeBoardNode', 'RecorderNode', 'FaceSwapNode', 'MoustacheNode',
+ 'SaiyanNode', 'BackgroundNode', 'XDwenDwenNode', 'ValentineMagicNode'
+]
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/builder.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..44900b7efdc9822e693ce572cca16dafda388640
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/builder.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry
+
+NODES = Registry('node')
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/faceswap_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/faceswap_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac44207fc363680aef49cfa1ea2b77707682484
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/faceswap_node.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import IntEnum
+from typing import List, Union
+
+import cv2
+import numpy as np
+
+from mmpose.datasets import DatasetInfo
+from .builder import NODES
+from .frame_drawing_node import FrameDrawingNode
+
+
+class Mode(IntEnum):
+ NONE = 0,
+ SHUFFLE = 1,
+ CLONE = 2
+
+
+@NODES.register_module()
+class FaceSwapNode(FrameDrawingNode):
+
+ def __init__(
+ self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ mode_key: Union[str, int],
+ ):
+ super().__init__(name, frame_buffer, output_buffer, enable=True)
+
+ self.mode_key = mode_key
+ self.mode_index = 0
+ self.register_event(
+ self.mode_key, is_keyboard=True, handler_func=self.switch_mode)
+ self.history = dict(mode=None)
+ self._mode = Mode.SHUFFLE
+
+ @property
+ def mode(self):
+ return self._mode
+
+ def switch_mode(self):
+ """Switch modes by updating mode index."""
+ self._mode = Mode((self._mode + 1) % len(Mode))
+
+ def draw(self, frame_msg):
+
+ if self.mode == Mode.NONE:
+ self.history = {'mode': Mode.NONE}
+ return frame_msg.get_image()
+
+ # Init history
+ if self.history['mode'] != self.mode:
+ self.history = {'mode': self.mode, 'target_map': {}}
+
+ # Merge pose results
+ pose_preds = self._merge_pose_results(frame_msg.get_pose_results())
+ num_target = len(pose_preds)
+
+ # Show mode
+ img = frame_msg.get_image()
+ canvas = img.copy()
+ if self.mode == Mode.SHUFFLE:
+ mode_txt = 'Shuffle'
+ else:
+ mode_txt = 'Clone'
+
+ cv2.putText(canvas, mode_txt, (10, 50), cv2.FONT_HERSHEY_DUPLEX, 0.8,
+ (255, 126, 0), 1)
+
+ # Skip if target number is less than 2
+ if num_target >= 2:
+ # Generate new mapping if target number changes
+ if num_target != len(self.history['target_map']):
+ if self.mode == Mode.SHUFFLE:
+ self.history['target_map'] = self._get_swap_map(num_target)
+ else:
+ self.history['target_map'] = np.repeat(
+ np.random.choice(num_target), num_target)
+
+ # # Draw on canvas
+ for tar_idx, src_idx in enumerate(self.history['target_map']):
+ face_src = self._get_face_info(pose_preds[src_idx])
+ face_tar = self._get_face_info(pose_preds[tar_idx])
+ canvas = self._swap_face(img, canvas, face_src, face_tar)
+
+ return canvas
+
+ def _crop_face_by_contour(self, img, contour):
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
+ cv2.fillPoly(mask, [contour.astype(np.int32)], 1)
+ mask = cv2.dilate(
+ mask, kernel=np.ones((9, 9), dtype=np.uint8), anchor=(4, 0))
+ x1, y1, w, h = cv2.boundingRect(mask)
+ x2 = x1 + w
+ y2 = y1 + h
+ bbox = np.array([x1, y1, x2, y2], dtype=np.int64)
+ patch = img[y1:y2, x1:x2]
+ mask = mask[y1:y2, x1:x2]
+
+ return bbox, patch, mask
+
+ def _swap_face(self, img_src, img_tar, face_src, face_tar):
+
+ if face_src['dataset'] == face_tar['dataset']:
+ # Use full keypoints for face alignment
+ kpts_src = face_src['contour']
+ kpts_tar = face_tar['contour']
+ else:
+ # Use only common landmarks (eyes and nose) for face alignment if
+ # source and target have differenet data type
+ # (e.g. human vs animal)
+ kpts_src = face_src['landmarks']
+ kpts_tar = face_tar['landmarks']
+
+ # Get everything local
+ bbox_src, patch_src, mask_src = self._crop_face_by_contour(
+ img_src, face_src['contour'])
+
+ bbox_tar, _, mask_tar = self._crop_face_by_contour(
+ img_tar, face_tar['contour'])
+
+ kpts_src = kpts_src - bbox_src[:2]
+ kpts_tar = kpts_tar - bbox_tar[:2]
+
+ # Compute affine transformation matrix
+ trans_mat, _ = cv2.estimateAffine2D(
+ kpts_src.astype(np.float32), kpts_tar.astype(np.float32))
+ patch_warp = cv2.warpAffine(
+ patch_src,
+ trans_mat,
+ dsize=tuple(bbox_tar[2:] - bbox_tar[:2]),
+ borderValue=(0, 0, 0))
+ mask_warp = cv2.warpAffine(
+ mask_src,
+ trans_mat,
+ dsize=tuple(bbox_tar[2:] - bbox_tar[:2]),
+ borderValue=(0, 0, 0))
+
+ # Target mask
+ mask_tar = mask_tar & mask_warp
+ mask_tar_soft = cv2.GaussianBlur(mask_tar * 255, (3, 3), 3)
+
+ # Blending
+ center = tuple((0.5 * (bbox_tar[:2] + bbox_tar[2:])).astype(np.int64))
+ img_tar = cv2.seamlessClone(patch_warp, img_tar, mask_tar_soft, center,
+ cv2.NORMAL_CLONE)
+ return img_tar
+
+ @staticmethod
+ def _get_face_info(pose_pred):
+ keypoints = pose_pred['keypoints'][:, :2]
+ model_cfg = pose_pred['model_cfg']
+ dataset_info = DatasetInfo(model_cfg.data.test.dataset_info)
+
+ face_info = {
+ 'dataset': dataset_info.dataset_name,
+ 'landmarks': None, # For alignment
+ 'contour': None, # For mask generation
+ 'bbox': None # For image warping
+ }
+
+ # Fall back to hard coded keypoint id
+
+ if face_info['dataset'] == 'coco':
+ face_info['landmarks'] = np.stack([
+ keypoints[1], # left eye
+ keypoints[2], # right eye
+ keypoints[0], # nose
+ 0.5 * (keypoints[5] + keypoints[6]), # neck (shoulder center)
+ ])
+ elif face_info['dataset'] == 'coco_wholebody':
+ face_info['landmarks'] = np.stack([
+ keypoints[1], # left eye
+ keypoints[2], # right eye
+ keypoints[0], # nose
+ keypoints[32], # chin
+ ])
+ contour_ids = list(range(23, 40)) + list(range(40, 50))[::-1]
+ face_info['contour'] = keypoints[contour_ids]
+ elif face_info['dataset'] == 'ap10k':
+ face_info['landmarks'] = np.stack([
+ keypoints[0], # left eye
+ keypoints[1], # right eye
+ keypoints[2], # nose
+ keypoints[3], # neck
+ ])
+ elif face_info['dataset'] == 'animalpose':
+ face_info['landmarks'] = np.stack([
+ keypoints[0], # left eye
+ keypoints[1], # right eye
+ keypoints[4], # nose
+ keypoints[5], # throat
+ ])
+ elif face_info['dataset'] == 'wflw':
+ face_info['landmarks'] = np.stack([
+ keypoints[97], # left eye
+ keypoints[96], # right eye
+ keypoints[54], # nose
+ keypoints[16], # chine
+ ])
+ contour_ids = list(range(33))[::-1] + list(range(33, 38)) + list(
+ range(42, 47))
+ face_info['contour'] = keypoints[contour_ids]
+ else:
+ raise ValueError('Can not obtain face landmark information'
+ f'from dataset: {face_info["type"]}')
+
+ # Face region
+ if face_info['contour'] is None:
+ # Manually defined counter of face region
+ left_eye, right_eye, nose = face_info['landmarks'][:3]
+ eye_center = 0.5 * (left_eye + right_eye)
+ w_vec = right_eye - left_eye
+ eye_dist = np.linalg.norm(w_vec) + 1e-6
+ w_vec = w_vec / eye_dist
+ h_vec = np.array([w_vec[1], -w_vec[0]], dtype=w_vec.dtype)
+ w = max(0.5 * eye_dist, np.abs(np.dot(nose - eye_center, w_vec)))
+ h = np.abs(np.dot(nose - eye_center, h_vec))
+
+ left_top = eye_center + 1.5 * w * w_vec - 0.5 * h * h_vec
+ right_top = eye_center - 1.5 * w * w_vec - 0.5 * h * h_vec
+ left_bottom = eye_center + 1.5 * w * w_vec + 4 * h * h_vec
+ right_bottom = eye_center - 1.5 * w * w_vec + 4 * h * h_vec
+
+ face_info['contour'] = np.stack(
+ [left_top, right_top, right_bottom, left_bottom])
+
+ # Get tight bbox of face region
+ face_info['bbox'] = np.array([
+ face_info['contour'][:, 0].min(), face_info['contour'][:, 1].min(),
+ face_info['contour'][:, 0].max(), face_info['contour'][:, 1].max()
+ ]).astype(np.int64)
+
+ return face_info
+
+ @staticmethod
+ def _merge_pose_results(pose_results):
+ preds = []
+ if pose_results is not None:
+ for prefix, pose_result in enumerate(pose_results):
+ model_cfg = pose_result['model_cfg']
+ for idx, _pred in enumerate(pose_result['preds']):
+ pred = _pred.copy()
+ pred['id'] = f'{prefix}.{_pred.get("track_id", str(idx))}'
+ pred['model_cfg'] = model_cfg
+ preds.append(pred)
+ return preds
+
+ @staticmethod
+ def _get_swap_map(num_target):
+ ids = np.random.choice(num_target, num_target, replace=False)
+ target_map = ids[(ids + 1) % num_target]
+ return target_map
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/frame_drawing_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/frame_drawing_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfc3511cadc2e8db0fb393ba1f821ee8091fcada
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/frame_drawing_node.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ..utils import FrameMessage, Message
+from .node import Node
+
+
+class FrameDrawingNode(Node):
+ """Base class for Node that draw on single frame images.
+
+ Args:
+ name (str, optional): The node name (also thread name).
+ frame_buffer (str): The name of the input buffer.
+ output_buffer (str | list): The name(s) of the output buffer(s).
+ enable_key (str | int, optional): Set a hot-key to toggle
+ enable/disable of the node. If an int value is given, it will be
+ treated as an ascii code of a key. Please note:
+ 1. If enable_key is set, the bypass method need to be
+ overridden to define the node behavior when disabled
+ 2. Some hot-key has been use for particular use. For example:
+ 'q', 'Q' and 27 are used for quit
+ Default: None
+ enable (bool): Default enable/disable status. Default: True.
+ """
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True):
+
+ super().__init__(name=name, enable_key=enable_key)
+
+ # Register buffers
+ self.register_input_buffer(frame_buffer, 'frame', essential=True)
+ self.register_output_buffer(output_buffer)
+
+ self._enabled = enable
+
+ def process(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ frame_msg = input_msgs['frame']
+
+ img = self.draw(frame_msg)
+ frame_msg.set_image(img)
+
+ return frame_msg
+
+ def bypass(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ return input_msgs['frame']
+
+ @abstractmethod
+ def draw(self, frame_msg: FrameMessage) -> np.ndarray:
+ """Draw on the frame image with information from the single frame.
+
+ Args:
+ frame_meg (FrameMessage): The frame to get information from and
+ draw on.
+
+ Returns:
+ array: The output image
+ """
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/frame_effect_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/frame_effect_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..c248c3820a944e6b5e7f0613794d6290fcda7bcc
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/frame_effect_node.py
@@ -0,0 +1,917 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from mmcv import color_val
+
+from mmpose.core import (apply_bugeye_effect, apply_sunglasses_effect,
+ imshow_bboxes, imshow_keypoints)
+from mmpose.datasets import DatasetInfo
+from ..utils import (FrameMessage, copy_and_paste, expand_and_clamp,
+ get_cached_file_path, get_eye_keypoint_ids,
+ get_face_keypoint_ids, get_wrist_keypoint_ids,
+ load_image_from_disk_or_url, screen_matting)
+from .builder import NODES
+from .frame_drawing_node import FrameDrawingNode
+
+try:
+ import psutil
+ psutil_proc = psutil.Process()
+except (ImportError, ModuleNotFoundError):
+ psutil_proc = None
+
+
+@NODES.register_module()
+class PoseVisualizerNode(FrameDrawingNode):
+ """Draw the bbox and keypoint detection results.
+
+ Args:
+ name (str, optional): The node name (also thread name).
+ frame_buffer (str): The name of the input buffer.
+ output_buffer (str|list): The name(s) of the output buffer(s).
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note:
+ 1. If enable_key is set, the bypass method need to be
+ overridden to define the node behavior when disabled
+ 2. Some hot-key has been use for particular use. For example:
+ 'q', 'Q' and 27 are used for quit
+ Default: None
+ enable (bool): Default enable/disable status. Default: True.
+ kpt_thr (float): The threshold of keypoint score. Default: 0.3.
+ radius (int): The radius of keypoint. Default: 4.
+ thickness (int): The thickness of skeleton. Default: 2.
+ bbox_color (str|tuple|dict): If a single color (a str like 'green' or
+ a tuple like (0, 255, 0)), it will used to draw the bbox.
+ Optionally, a dict can be given as a map from class labels to
+ colors.
+ """
+
+ default_bbox_color = {
+ 'person': (148, 139, 255),
+ 'cat': (255, 255, 0),
+ 'dog': (255, 255, 0),
+ }
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ kpt_thr: float = 0.3,
+ radius: int = 4,
+ thickness: int = 2,
+ bbox_color: Optional[Union[str, Tuple, Dict]] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ self.kpt_thr = kpt_thr
+ self.radius = radius
+ self.thickness = thickness
+ if bbox_color is None:
+ self.bbox_color = self.default_bbox_color
+ elif isinstance(bbox_color, dict):
+ self.bbox_color = {k: color_val(v) for k, v in bbox_color.items()}
+ else:
+ self.bbox_color = color_val(bbox_color)
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+
+ if not pose_results:
+ return canvas
+
+ for pose_result in frame_msg.get_pose_results():
+ model_cfg = pose_result['model_cfg']
+ dataset_info = DatasetInfo(model_cfg.dataset_info)
+
+ # Extract bboxes and poses
+ bbox_preds = []
+ bbox_labels = []
+ pose_preds = []
+ for pred in pose_result['preds']:
+ if 'bbox' in pred:
+ bbox_preds.append(pred['bbox'])
+ bbox_labels.append(pred.get('label', None))
+ pose_preds.append(pred['keypoints'])
+
+ # Get bbox colors
+ if isinstance(self.bbox_color, dict):
+ bbox_colors = [
+ self.bbox_color.get(label, (0, 255, 0))
+ for label in bbox_labels
+ ]
+ else:
+ bbox_labels = self.bbox_color
+
+ # Draw bboxes
+ if bbox_preds:
+ bboxes = np.vstack(bbox_preds)
+
+ imshow_bboxes(
+ canvas,
+ bboxes,
+ labels=bbox_labels,
+ colors=bbox_colors,
+ text_color='white',
+ font_scale=0.5,
+ show=False)
+
+ # Draw poses
+ if pose_preds:
+ imshow_keypoints(
+ canvas,
+ pose_preds,
+ skeleton=dataset_info.skeleton,
+ kpt_score_thr=0.3,
+ pose_kpt_color=dataset_info.pose_kpt_color,
+ pose_link_color=dataset_info.pose_link_color,
+ radius=self.radius,
+ thickness=self.thickness)
+
+ return canvas
+
+
+@NODES.register_module()
+class SunglassesNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ src_img_path: Optional[str] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ if src_img_path is None:
+ # The image attributes to:
+ # https://www.vecteezy.com/free-vector/glass
+ # Glass Vectors by Vecteezy
+ src_img_path = 'demo/resources/sunglasses.jpg'
+ self.src_img = load_image_from_disk_or_url(src_img_path)
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ left_eye_idx, right_eye_idx = get_eye_keypoint_ids(model_cfg)
+
+ canvas = apply_sunglasses_effect(canvas, preds, self.src_img,
+ left_eye_idx, right_eye_idx)
+ return canvas
+
+
+@NODES.register_module()
+class SpriteNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ src_img_path: Optional[str] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ if src_img_path is None:
+ # Sprites of Touhou characters :)
+ # Come from https://www.deviantart.com/shadowbendy/art/Touhou-rpg-maker-vx-Sprite-1-812746920 # noqa: E501
+ src_img_path = (
+ 'https://user-images.githubusercontent.com/'
+ '26739999/151532276-33f968d9-917f-45e3-8a99-ebde60be83bb.png')
+ self.src_img = load_image_from_disk_or_url(
+ src_img_path, cv2.IMREAD_UNCHANGED)[:144, :108]
+ tmp = np.array(np.split(self.src_img, range(36, 144, 36), axis=0))
+ tmp = np.array(np.split(tmp, range(36, 108, 36), axis=2))
+ self.sprites = tmp
+ self.pos = None
+ self.anime_frame = 0
+
+ def apply_sprite_effect(self,
+ img,
+ pose_results,
+ left_hand_index,
+ right_hand_index,
+ kpt_thr=0.5):
+ """Apply sprite effect.
+
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "keypoints" ([K,3]): detection result in [x, y, score]
+ left_hand_index (int): Keypoint index of left hand
+ right_hand_index (int): Keypoint index of right hand
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+
+ hm, wm = self.sprites.shape[2:4]
+ # anchor points in the sunglasses mask
+ if self.pos is None:
+ self.pos = [img.shape[0] // 2, img.shape[1] // 2]
+
+ if len(pose_results) == 0:
+ return img
+
+ kpts = pose_results[0]['keypoints']
+
+ if kpts[left_hand_index, 2] < kpt_thr and kpts[right_hand_index,
+ 2] < kpt_thr:
+ aim = self.pos
+ else:
+ kpt_lhand = kpts[left_hand_index, :2][::-1]
+ kpt_rhand = kpts[right_hand_index, :2][::-1]
+
+ def distance(a, b):
+ return (a[0] - b[0])**2 + (a[1] - b[1])**2
+
+ # Go to the nearest hand
+ if distance(kpt_lhand, self.pos) < distance(kpt_rhand, self.pos):
+ aim = kpt_lhand
+ else:
+ aim = kpt_rhand
+
+ pos_thr = 15
+ if aim[0] < self.pos[0] - pos_thr:
+ # Go down
+ sprite = self.sprites[self.anime_frame][3]
+ self.pos[0] -= 1
+ elif aim[0] > self.pos[0] + pos_thr:
+ # Go up
+ sprite = self.sprites[self.anime_frame][0]
+ self.pos[0] += 1
+ elif aim[1] < self.pos[1] - pos_thr:
+ # Go right
+ sprite = self.sprites[self.anime_frame][1]
+ self.pos[1] -= 1
+ elif aim[1] > self.pos[1] + pos_thr:
+ # Go left
+ sprite = self.sprites[self.anime_frame][2]
+ self.pos[1] += 1
+ else:
+ # Stay
+ self.anime_frame = 0
+ sprite = self.sprites[self.anime_frame][0]
+
+ if self.anime_frame < 2:
+ self.anime_frame += 1
+ else:
+ self.anime_frame = 0
+
+ x = self.pos[0] - hm // 2
+ y = self.pos[1] - wm // 2
+ x = max(0, min(x, img.shape[0] - hm))
+ y = max(0, min(y, img.shape[0] - wm))
+
+ # Overlay image with transparent
+ img[x:x + hm, y:y +
+ wm] = (img[x:x + hm, y:y + wm] * (1 - sprite[:, :, 3:] / 255) +
+ sprite[:, :, :3] * (sprite[:, :, 3:] / 255)).astype('uint8')
+
+ return img
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ # left_hand_idx, right_hand_idx = get_wrist_keypoint_ids(model_cfg) # noqa: E501
+ left_hand_idx, right_hand_idx = get_eye_keypoint_ids(model_cfg)
+
+ canvas = self.apply_sprite_effect(canvas, preds, left_hand_idx,
+ right_hand_idx)
+ return canvas
+
+
+@NODES.register_module()
+class BackgroundNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ src_img_path: Optional[str] = None,
+ cls_ids: Optional[List] = None,
+ cls_names: Optional[List] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ self.cls_ids = cls_ids
+ self.cls_names = cls_names
+
+ if src_img_path is None:
+ src_img_path = 'https://user-images.githubusercontent.com/'\
+ '11788150/149731957-abd5c908-9c7f-45b2-b7bf-'\
+ '821ab30c6a3e.jpg'
+ self.src_img = load_image_from_disk_or_url(src_img_path)
+
+ def apply_background_effect(self,
+ img,
+ det_results,
+ background_img,
+ effect_region=(0.2, 0.2, 0.8, 0.8)):
+ """Change background.
+
+ Args:
+ img (np.ndarray): Image data.
+ det_results (list[dict]): The detection results containing:
+
+ - "cls_id" (int): Class index.
+ - "label" (str): Class label (e.g. 'person').
+ - "bbox" (ndarray:(5, )): bounding box result
+ [x, y, w, h, score].
+ - "mask" (ndarray:(w, h)): instance segmentation result.
+ background_img (np.ndarray): Background image.
+ effect_region (tuple(4, )): The region to apply mask,
+ the coordinates are normalized (x1, y1, x2, y2).
+ """
+ if len(det_results) > 0:
+ # Choose the one with the highest score.
+ det_result = det_results[0]
+ bbox = det_result['bbox']
+ mask = det_result['mask'].astype(np.uint8)
+ img = copy_and_paste(img, background_img, mask, bbox,
+ effect_region)
+ return img
+ else:
+ return background_img
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ if canvas.shape != self.src_img.shape:
+ self.src_img = cv2.resize(self.src_img, canvas.shape[:2])
+ det_results = frame_msg.get_detection_results()
+ if not det_results:
+ return canvas
+
+ full_preds = []
+ for det_result in det_results:
+ preds = det_result['preds']
+ if self.cls_ids:
+ # Filter results by class ID
+ filtered_preds = [
+ p for p in preds if p['cls_id'] in self.cls_ids
+ ]
+ elif self.cls_names:
+ # Filter results by class name
+ filtered_preds = [
+ p for p in preds if p['label'] in self.cls_names
+ ]
+ else:
+ filtered_preds = preds
+ full_preds.extend(filtered_preds)
+
+ canvas = self.apply_background_effect(canvas, full_preds, self.src_img)
+
+ return canvas
+
+
+@NODES.register_module()
+class SaiyanNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ hair_img_path: Optional[str] = None,
+ light_video_path: Optional[str] = None,
+ cls_ids: Optional[List] = None,
+ cls_names: Optional[List] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ self.cls_ids = cls_ids
+ self.cls_names = cls_names
+
+ if hair_img_path is None:
+ hair_img_path = 'https://user-images.githubusercontent.com/'\
+ '11788150/149732117-fcd2d804-dc2c-426c-bee7-'\
+ '94be6146e05c.png'
+ self.hair_img = load_image_from_disk_or_url(hair_img_path)
+
+ if light_video_path is None:
+ light_video_path = get_cached_file_path(
+ 'https://'
+ 'user-images.githubusercontent.com/11788150/149732080'
+ '-ea6cfeda-0dc5-4bbb-892a-3831e5580520.mp4')
+ self.light_video_path = light_video_path
+ self.light_video = cv2.VideoCapture(self.light_video_path)
+
+ def apply_saiyan_effect(self,
+ img,
+ pose_results,
+ saiyan_img,
+ light_frame,
+ face_indices,
+ bbox_thr=0.3,
+ kpt_thr=0.5):
+ """Apply saiyan hair effect.
+
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "keypoints" ([K,3]): keypoint detection result
+ in [x, y, score]
+ saiyan_img (np.ndarray): Saiyan image with transparent background.
+ light_frame (np.ndarray): Light image with green screen.
+ face_indices (int): Keypoint index of the face
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+ img = img.copy()
+ im_shape = img.shape
+ # Apply lightning effects.
+ light_mask = screen_matting(light_frame, color='green')
+
+ # anchor points in the mask
+ pts_src = np.array(
+ [
+ [84, 398], # face kpt 0
+ [331, 393], # face kpt 16
+ [84, 145],
+ [331, 140]
+ ],
+ dtype=np.float32)
+
+ for pose in pose_results:
+ bbox = pose['bbox']
+
+ if bbox[-1] < bbox_thr:
+ continue
+
+ mask_inst = pose['mask']
+ # cache
+ fg = img[np.where(mask_inst)]
+
+ bbox = expand_and_clamp(bbox[:4], im_shape, s=3.0)
+ # Apply light effects between fg and bg
+ img = copy_and_paste(
+ light_frame,
+ img,
+ light_mask,
+ effect_region=(bbox[0] / im_shape[1], bbox[1] / im_shape[0],
+ bbox[2] / im_shape[1], bbox[3] / im_shape[0]))
+ # pop
+ img[np.where(mask_inst)] = fg
+
+ # Apply Saiyan hair effects
+ kpts = pose['keypoints']
+ if kpts[face_indices[0], 2] < kpt_thr or kpts[face_indices[16],
+ 2] < kpt_thr:
+ continue
+
+ kpt_0 = kpts[face_indices[0], :2]
+ kpt_16 = kpts[face_indices[16], :2]
+ # orthogonal vector
+ vo = (kpt_0 - kpt_16)[::-1] * [-1, 1]
+
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack([kpt_0, kpt_16, kpt_0 + vo, kpt_16 + vo])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ saiyan_img,
+ h_mat,
+ dsize=(img.shape[1], img.shape[0]),
+ borderValue=(0, 0, 0))
+ mask_patch = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask_patch = (mask_patch > 1).astype(np.uint8)
+ img = cv2.copyTo(patch, mask_patch, img)
+
+ return img
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+
+ det_results = frame_msg.get_detection_results()
+ if not det_results:
+ return canvas
+
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ face_indices = get_face_keypoint_ids(model_cfg)
+
+ ret, frame = self.light_video.read()
+ if not ret:
+ self.light_video = cv2.VideoCapture(self.light_video_path)
+ ret, frame = self.light_video.read()
+
+ canvas = self.apply_saiyan_effect(canvas, preds, self.hair_img,
+ frame, face_indices)
+
+ return canvas
+
+
+@NODES.register_module()
+class MoustacheNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ src_img_path: Optional[str] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ if src_img_path is None:
+ src_img_path = 'https://user-images.githubusercontent.com/'\
+ '11788150/149732141-3afbab55-252a-428c-b6d8'\
+ '-0e352f432651.jpeg'
+ self.src_img = load_image_from_disk_or_url(src_img_path)
+
+ def apply_moustache_effect(self,
+ img,
+ pose_results,
+ moustache_img,
+ face_indices,
+ kpt_thr=0.5):
+ """Apply moustache effect.
+
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "keypoints" ([K,3]): keypoint detection result
+ in [x, y, score]
+ moustache_img (np.ndarray): Moustache image with white background.
+ left_eye_index (int): Keypoint index of left eye
+ right_eye_index (int): Keypoint index of right eye
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+
+ hm, wm = moustache_img.shape[:2]
+ # anchor points in the moustache mask
+ pts_src = np.array([[1164, 741], [1729, 741], [1164, 1244],
+ [1729, 1244]],
+ dtype=np.float32)
+
+ for pose in pose_results:
+ kpts = pose['keypoints']
+ if kpts[face_indices[32], 2] < kpt_thr \
+ or kpts[face_indices[34], 2] < kpt_thr \
+ or kpts[face_indices[61], 2] < kpt_thr \
+ or kpts[face_indices[63], 2] < kpt_thr:
+ continue
+
+ kpt_32 = kpts[face_indices[32], :2]
+ kpt_34 = kpts[face_indices[34], :2]
+ kpt_61 = kpts[face_indices[61], :2]
+ kpt_63 = kpts[face_indices[63], :2]
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack([kpt_32, kpt_34, kpt_61, kpt_63])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ moustache_img,
+ h_mat,
+ dsize=(img.shape[1], img.shape[0]),
+ borderValue=(255, 255, 255))
+ # mask the white background area in the patch with a threshold 200
+ mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask = (mask < 200).astype(np.uint8)
+ img = cv2.copyTo(patch, mask, img)
+
+ return img
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ face_indices = get_face_keypoint_ids(model_cfg)
+ canvas = self.apply_moustache_effect(canvas, preds, self.src_img,
+ face_indices)
+ return canvas
+
+
+@NODES.register_module()
+class BugEyeNode(FrameDrawingNode):
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ left_eye_idx, right_eye_idx = get_eye_keypoint_ids(model_cfg)
+
+ canvas = apply_bugeye_effect(canvas, preds, left_eye_idx,
+ right_eye_idx)
+ return canvas
+
+
+@NODES.register_module()
+class NoticeBoardNode(FrameDrawingNode):
+
+ default_content_lines = ['This is a notice board!']
+
+ def __init__(
+ self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ content_lines: Optional[List[str]] = None,
+ x_offset: int = 20,
+ y_offset: int = 20,
+ y_delta: int = 15,
+ text_color: Union[str, Tuple[int, int, int]] = 'black',
+ background_color: Union[str, Tuple[int, int, int]] = (255, 183, 0),
+ text_scale: float = 0.4,
+ ):
+ super().__init__(name, frame_buffer, output_buffer, enable_key, enable)
+
+ self.x_offset = x_offset
+ self.y_offset = y_offset
+ self.y_delta = y_delta
+ self.text_color = color_val(text_color)
+ self.background_color = color_val(background_color)
+ self.text_scale = text_scale
+
+ if content_lines:
+ self.content_lines = content_lines
+ else:
+ self.content_lines = self.default_content_lines
+
+ def draw(self, frame_msg: FrameMessage) -> np.ndarray:
+ img = frame_msg.get_image()
+ canvas = np.full(img.shape, self.background_color, dtype=img.dtype)
+
+ x = self.x_offset
+ y = self.y_offset
+
+ max_len = max([len(line) for line in self.content_lines])
+
+ def _put_line(line=''):
+ nonlocal y
+ cv2.putText(canvas, line, (x, y), cv2.FONT_HERSHEY_DUPLEX,
+ self.text_scale, self.text_color, 1)
+ y += self.y_delta
+
+ for line in self.content_lines:
+ _put_line(line)
+
+ x1 = max(0, self.x_offset)
+ x2 = min(img.shape[1], int(x + max_len * self.text_scale * 20))
+ y1 = max(0, self.y_offset - self.y_delta)
+ y2 = min(img.shape[0], y)
+
+ src1 = canvas[y1:y2, x1:x2]
+ src2 = img[y1:y2, x1:x2]
+ img[y1:y2, x1:x2] = cv2.addWeighted(src1, 0.5, src2, 0.5, 0)
+
+ return img
+
+
+@NODES.register_module()
+class HatNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ src_img_path: Optional[str] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key)
+
+ if src_img_path is None:
+ # The image attributes to:
+ # http://616pic.com/sucai/1m9i70p52.html
+ src_img_path = 'https://user-images.githubusercontent.' \
+ 'com/28900607/149766271-2f591c19-9b67-4' \
+ 'd92-8f94-c272396ca141.png'
+ self.src_img = load_image_from_disk_or_url(src_img_path,
+ cv2.IMREAD_UNCHANGED)
+
+ @staticmethod
+ def apply_hat_effect(img,
+ pose_results,
+ hat_img,
+ left_eye_index,
+ right_eye_index,
+ kpt_thr=0.5):
+ """Apply hat effect.
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "keypoints" ([K,3]): keypoint detection result in
+ [x, y, score]
+ hat_img (np.ndarray): Hat image with white alpha channel.
+ left_eye_index (int): Keypoint index of left eye
+ right_eye_index (int): Keypoint index of right eye
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+ img_orig = img.copy()
+
+ img = img_orig.copy()
+ hm, wm = hat_img.shape[:2]
+ # anchor points in the sunglasses mask
+ a = 0.3
+ b = 0.7
+ pts_src = np.array([[a * wm, a * hm], [a * wm, b * hm],
+ [b * wm, a * hm], [b * wm, b * hm]],
+ dtype=np.float32)
+
+ for pose in pose_results:
+ kpts = pose['keypoints']
+
+ if kpts[left_eye_index, 2] < kpt_thr or \
+ kpts[right_eye_index, 2] < kpt_thr:
+ continue
+
+ kpt_leye = kpts[left_eye_index, :2]
+ kpt_reye = kpts[right_eye_index, :2]
+ # orthogonal vector to the left-to-right eyes
+ vo = 0.5 * (kpt_reye - kpt_leye)[::-1] * [-1, 1]
+ veye = 0.5 * (kpt_reye - kpt_leye)
+
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack([
+ kpt_reye + 1 * veye + 5 * vo, kpt_reye + 1 * veye + 1 * vo,
+ kpt_leye - 1 * veye + 5 * vo, kpt_leye - 1 * veye + 1 * vo
+ ])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ hat_img,
+ h_mat,
+ dsize=(img.shape[1], img.shape[0]),
+ borderValue=(255, 255, 255))
+ # mask the white background area in the patch with a threshold 200
+ mask = (patch[:, :, -1] > 128)
+ patch = patch[:, :, :-1]
+ mask = mask * (cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY) > 30)
+ mask = mask.astype(np.uint8)
+
+ img = cv2.copyTo(patch, mask, img)
+ return img
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ left_eye_idx, right_eye_idx = get_eye_keypoint_ids(model_cfg)
+
+ canvas = self.apply_hat_effect(canvas, preds, self.src_img,
+ left_eye_idx, right_eye_idx)
+ return canvas
+
+
+@NODES.register_module()
+class FirecrackerNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ src_img_path: Optional[str] = None):
+
+ super().__init__(name, frame_buffer, output_buffer, enable_key)
+
+ if src_img_path is None:
+ self.src_img_path = 'https://user-images.githubusercontent' \
+ '.com/28900607/149766281-6376055c-ed8b' \
+ '-472b-991f-60e6ae6ee1da.gif'
+ src_img = cv2.VideoCapture(self.src_img_path)
+
+ self.frame_list = []
+ ret, frame = src_img.read()
+ while frame is not None:
+ self.frame_list.append(frame)
+ ret, frame = src_img.read()
+ self.num_frames = len(self.frame_list)
+ self.frame_idx = 0
+ self.frame_period = 4 # each frame in gif lasts for 4 frames in video
+
+ @staticmethod
+ def apply_firecracker_effect(img,
+ pose_results,
+ firecracker_img,
+ left_wrist_idx,
+ right_wrist_idx,
+ kpt_thr=0.5):
+ """Apply firecracker effect.
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "keypoints" ([K,3]): keypoint detection result in
+ [x, y, score]
+ firecracker_img (np.ndarray): Firecracker image with white
+ background.
+ left_wrist_idx (int): Keypoint index of left wrist
+ right_wrist_idx (int): Keypoint index of right wrist
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+
+ hm, wm = firecracker_img.shape[:2]
+ # anchor points in the firecracker mask
+ pts_src = np.array([[0. * wm, 0. * hm], [0. * wm, 1. * hm],
+ [1. * wm, 0. * hm], [1. * wm, 1. * hm]],
+ dtype=np.float32)
+
+ h, w = img.shape[:2]
+ h_tar = h / 3
+ w_tar = h_tar / hm * wm
+
+ for pose in pose_results:
+ kpts = pose['keypoints']
+
+ if kpts[left_wrist_idx, 2] > kpt_thr:
+ kpt_lwrist = kpts[left_wrist_idx, :2]
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack([
+ kpt_lwrist - [w_tar / 2, 0],
+ kpt_lwrist - [w_tar / 2, -h_tar],
+ kpt_lwrist + [w_tar / 2, 0],
+ kpt_lwrist + [w_tar / 2, h_tar]
+ ])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ firecracker_img,
+ h_mat,
+ dsize=(img.shape[1], img.shape[0]),
+ borderValue=(255, 255, 255))
+ # mask the white background area in the patch with
+ # a threshold 200
+ mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask = (mask < 240).astype(np.uint8)
+ img = cv2.copyTo(patch, mask, img)
+
+ if kpts[right_wrist_idx, 2] > kpt_thr:
+ kpt_rwrist = kpts[right_wrist_idx, :2]
+
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack([
+ kpt_rwrist - [w_tar / 2, 0],
+ kpt_rwrist - [w_tar / 2, -h_tar],
+ kpt_rwrist + [w_tar / 2, 0],
+ kpt_rwrist + [w_tar / 2, h_tar]
+ ])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ firecracker_img,
+ h_mat,
+ dsize=(img.shape[1], img.shape[0]),
+ borderValue=(255, 255, 255))
+ # mask the white background area in the patch with
+ # a threshold 200
+ mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask = (mask < 240).astype(np.uint8)
+ img = cv2.copyTo(patch, mask, img)
+
+ return img
+
+ def draw(self, frame_msg):
+ canvas = frame_msg.get_image()
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+
+ frame = self.frame_list[self.frame_idx // self.frame_period]
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+ preds = pose_result['preds']
+ left_wrist_idx, right_wrist_idx = get_wrist_keypoint_ids(model_cfg)
+
+ canvas = self.apply_firecracker_effect(canvas, preds, frame,
+ left_wrist_idx,
+ right_wrist_idx)
+ self.frame_idx = (self.frame_idx + 1) % (
+ self.num_frames * self.frame_period)
+
+ return canvas
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/helper_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/helper_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..349c4f423456781a092d83fc6382d7f9f3376fd8
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/helper_node.py
@@ -0,0 +1,296 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import time
+from queue import Full, Queue
+from threading import Thread
+from typing import List, Optional, Union
+
+import cv2
+import numpy as np
+from mmcv import color_val
+
+from mmpose.utils.timer import RunningAverage
+from .builder import NODES
+from .node import Node
+
+try:
+ import psutil
+ psutil_proc = psutil.Process()
+except (ImportError, ModuleNotFoundError):
+ psutil_proc = None
+
+
+@NODES.register_module()
+class ModelResultBindingNode(Node):
+
+ def __init__(self, name: str, frame_buffer: str, result_buffer: str,
+ output_buffer: Union[str, List[str]]):
+ super().__init__(name=name, enable=True)
+ self.synchronous = None
+
+ # Cache the latest model result
+ self.last_result_msg = None
+ self.last_output_msg = None
+
+ # Inference speed analysis
+ self.frame_fps = RunningAverage(window=10)
+ self.frame_lag = RunningAverage(window=10)
+ self.result_fps = RunningAverage(window=10)
+ self.result_lag = RunningAverage(window=10)
+
+ # Register buffers
+ # Note that essential buffers will be set in set_runner() because
+ # it depends on the runner.synchronous attribute.
+ self.register_input_buffer(result_buffer, 'result', essential=False)
+ self.register_input_buffer(frame_buffer, 'frame', essential=False)
+ self.register_output_buffer(output_buffer)
+
+ def set_runner(self, runner):
+ super().set_runner(runner)
+
+ # Set synchronous according to the runner
+ if runner.synchronous:
+ self.synchronous = True
+ essential_input = 'result'
+ else:
+ self.synchronous = False
+ essential_input = 'frame'
+
+ # Set essential input buffer according to the synchronous setting
+ for buffer_info in self._input_buffers:
+ if buffer_info.input_name == essential_input:
+ buffer_info.essential = True
+
+ def process(self, input_msgs):
+ result_msg = input_msgs['result']
+
+ # Update last result
+ if result_msg is not None:
+ # Update result FPS
+ if self.last_result_msg is not None:
+ self.result_fps.update(
+ 1.0 /
+ (result_msg.timestamp - self.last_result_msg.timestamp))
+ # Update inference latency
+ self.result_lag.update(time.time() - result_msg.timestamp)
+ # Update last inference result
+ self.last_result_msg = result_msg
+
+ if not self.synchronous:
+ # Asynchronous mode: Bind the latest result with the current frame.
+ frame_msg = input_msgs['frame']
+
+ self.frame_lag.update(time.time() - frame_msg.timestamp)
+
+ # Bind result to frame
+ if self.last_result_msg is not None:
+ frame_msg.set_full_results(
+ self.last_result_msg.get_full_results())
+ frame_msg.merge_route_info(
+ self.last_result_msg.get_route_info())
+
+ output_msg = frame_msg
+
+ else:
+ # Synchronous mode: Directly output the frame that the model result
+ # was obtained from.
+ self.frame_lag.update(time.time() - result_msg.timestamp)
+ output_msg = result_msg
+
+ # Update frame fps and lag
+ if self.last_output_msg is not None:
+ self.frame_lag.update(time.time() - output_msg.timestamp)
+ self.frame_fps.update(
+ 1.0 / (output_msg.timestamp - self.last_output_msg.timestamp))
+ self.last_output_msg = output_msg
+
+ return output_msg
+
+ def _get_node_info(self):
+ info = super()._get_node_info()
+ info['result_fps'] = self.result_fps.average()
+ info['result_lag (ms)'] = self.result_lag.average() * 1000
+ info['frame_fps'] = self.frame_fps.average()
+ info['frame_lag (ms)'] = self.frame_lag.average() * 1000
+ return info
+
+
+@NODES.register_module()
+class MonitorNode(Node):
+
+ _default_ignore_items = ['timestamp']
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = False,
+ x_offset=20,
+ y_offset=20,
+ y_delta=15,
+ text_color='black',
+ background_color=(255, 183, 0),
+ text_scale=0.4,
+ ignore_items: Optional[List[str]] = None):
+ super().__init__(name=name, enable_key=enable_key, enable=enable)
+
+ self.x_offset = x_offset
+ self.y_offset = y_offset
+ self.y_delta = y_delta
+ self.text_color = color_val(text_color)
+ self.background_color = color_val(background_color)
+ self.text_scale = text_scale
+ if ignore_items is None:
+ self.ignore_items = self._default_ignore_items
+ else:
+ self.ignore_items = ignore_items
+
+ self.register_input_buffer(frame_buffer, 'frame', essential=True)
+ self.register_output_buffer(output_buffer)
+
+ def process(self, input_msgs):
+ frame_msg = input_msgs['frame']
+
+ frame_msg.update_route_info(
+ node_name='System Info',
+ node_type='dummy',
+ info=self._get_system_info())
+
+ img = frame_msg.get_image()
+ route_info = frame_msg.get_route_info()
+ img = self._show_route_info(img, route_info)
+
+ frame_msg.set_image(img)
+ return frame_msg
+
+ def _get_system_info(self):
+ sys_info = {}
+ if psutil_proc is not None:
+ sys_info['CPU(%)'] = psutil_proc.cpu_percent()
+ sys_info['Memory(%)'] = psutil_proc.memory_percent()
+ return sys_info
+
+ def _show_route_info(self, img, route_info):
+ canvas = np.full(img.shape, self.background_color, dtype=img.dtype)
+
+ x = self.x_offset
+ y = self.y_offset
+
+ max_len = 0
+
+ def _put_line(line=''):
+ nonlocal y, max_len
+ cv2.putText(canvas, line, (x, y), cv2.FONT_HERSHEY_DUPLEX,
+ self.text_scale, self.text_color, 1)
+ y += self.y_delta
+ max_len = max(max_len, len(line))
+
+ for node_info in route_info:
+ title = f'{node_info["node"]}({node_info["node_type"]})'
+ _put_line(title)
+ for k, v in node_info['info'].items():
+ if k in self.ignore_items:
+ continue
+ if isinstance(v, float):
+ v = f'{v:.1f}'
+ _put_line(f' {k}: {v}')
+
+ x1 = max(0, self.x_offset)
+ x2 = min(img.shape[1], int(x + max_len * self.text_scale * 20))
+ y1 = max(0, self.y_offset - self.y_delta)
+ y2 = min(img.shape[0], y)
+
+ src1 = canvas[y1:y2, x1:x2]
+ src2 = img[y1:y2, x1:x2]
+ img[y1:y2, x1:x2] = cv2.addWeighted(src1, 0.5, src2, 0.5, 0)
+
+ return img
+
+ def bypass(self, input_msgs):
+ return input_msgs['frame']
+
+
+@NODES.register_module()
+class RecorderNode(Node):
+ """Record the frames into a local file."""
+
+ def __init__(
+ self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ out_video_file: str,
+ out_video_fps: int = 30,
+ out_video_codec: str = 'mp4v',
+ buffer_size: int = 30,
+ ):
+ super().__init__(name=name, enable_key=None, enable=True)
+
+ self.queue = Queue(maxsize=buffer_size)
+ self.out_video_file = out_video_file
+ self.out_video_fps = out_video_fps
+ self.out_video_codec = out_video_codec
+ self.vwriter = None
+
+ # Register buffers
+ self.register_input_buffer(frame_buffer, 'frame', essential=True)
+ self.register_output_buffer(output_buffer)
+
+ # Start a new thread to write frame
+ self.t_record = Thread(target=self._record, args=(), daemon=True)
+ self.t_record.start()
+
+ def process(self, input_msgs):
+
+ frame_msg = input_msgs['frame']
+ img = frame_msg.get_image() if frame_msg is not None else None
+ img_queued = False
+
+ while not img_queued:
+ try:
+ self.queue.put(img, timeout=1)
+ img_queued = True
+ logging.info(f'{self.name}: recorder received one frame!')
+ except Full:
+ logging.info(f'{self.name}: recorder jamed!')
+
+ return frame_msg
+
+ def _record(self):
+
+ while True:
+
+ img = self.queue.get()
+
+ if img is None:
+ break
+
+ if self.vwriter is None:
+ fourcc = cv2.VideoWriter_fourcc(*self.out_video_codec)
+ fps = self.out_video_fps
+ frame_size = (img.shape[1], img.shape[0])
+ self.vwriter = cv2.VideoWriter(self.out_video_file, fourcc,
+ fps, frame_size)
+ assert self.vwriter.isOpened()
+
+ self.vwriter.write(img)
+
+ logging.info('Video recorder released!')
+ if self.vwriter is not None:
+ self.vwriter.release()
+
+ def on_exit(self):
+ try:
+ # Try putting a None into the output queue so the self.vwriter will
+ # be released after all queue frames have been written to file.
+ self.queue.put(None, timeout=1)
+ self.t_record.join(timeout=1)
+ except Full:
+ pass
+
+ if self.t_record.is_alive():
+ # Force to release self.vwriter
+ logging.info('Video recorder forced release!')
+ if self.vwriter is not None:
+ self.vwriter.release()
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/mmdet_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/mmdet_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..4207647c927dfbd34af225454ed5c2ef7466a012
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/mmdet_node.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+from .builder import NODES
+from .node import Node
+
+try:
+ from mmdet.apis import inference_detector, init_detector
+ has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+ has_mmdet = False
+
+
+@NODES.register_module()
+class DetectorNode(Node):
+
+ def __init__(self,
+ name: str,
+ model_config: str,
+ model_checkpoint: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ device: str = 'cuda:0'):
+ # Check mmdetection is installed
+ assert has_mmdet, 'Please install mmdet to run the demo.'
+ super().__init__(name=name, enable_key=enable_key, enable=True)
+
+ self.model_config = model_config
+ self.model_checkpoint = model_checkpoint
+ self.device = device.lower()
+
+ # Init model
+ self.model = init_detector(
+ self.model_config,
+ self.model_checkpoint,
+ device=self.device.lower())
+
+ # Register buffers
+ self.register_input_buffer(input_buffer, 'input', essential=True)
+ self.register_output_buffer(output_buffer)
+
+ def bypass(self, input_msgs):
+ return input_msgs['input']
+
+ def process(self, input_msgs):
+ input_msg = input_msgs['input']
+
+ img = input_msg.get_image()
+
+ preds = inference_detector(self.model, img)
+ det_result = self._post_process(preds)
+
+ input_msg.add_detection_result(det_result, tag=self.name)
+ return input_msg
+
+ def _post_process(self, preds):
+ if isinstance(preds, tuple):
+ dets = preds[0]
+ segms = preds[1]
+ else:
+ dets = preds
+ segms = [None] * len(dets)
+
+ assert len(dets) == len(self.model.CLASSES)
+ assert len(segms) == len(self.model.CLASSES)
+ result = {'preds': [], 'model_cfg': self.model.cfg.copy()}
+
+ for i, (cls_name, bboxes,
+ masks) in enumerate(zip(self.model.CLASSES, dets, segms)):
+ if masks is None:
+ masks = [None] * len(bboxes)
+ else:
+ assert len(masks) == len(bboxes)
+
+ preds_i = [{
+ 'cls_id': i,
+ 'label': cls_name,
+ 'bbox': bbox,
+ 'mask': mask,
+ } for (bbox, mask) in zip(bboxes, masks)]
+ result['preds'].extend(preds_i)
+
+ return result
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/mmpose_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/mmpose_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..167d7413ea48943b9373525bf5f392b5f1aa248b
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/mmpose_node.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from typing import Dict, List, Optional, Union
+
+from mmpose.apis import (get_track_id, inference_top_down_pose_model,
+ init_pose_model)
+from ..utils import Message
+from .builder import NODES
+from .node import Node
+
+
+@NODES.register_module()
+class TopDownPoseEstimatorNode(Node):
+
+ def __init__(self,
+ name: str,
+ model_config: str,
+ model_checkpoint: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ device: str = 'cuda:0',
+ cls_ids: Optional[List] = None,
+ cls_names: Optional[List] = None,
+ bbox_thr: float = 0.5):
+ super().__init__(name=name, enable_key=enable_key, enable=enable)
+
+ # Init model
+ self.model_config = model_config
+ self.model_checkpoint = model_checkpoint
+ self.device = device.lower()
+
+ self.cls_ids = cls_ids
+ self.cls_names = cls_names
+ self.bbox_thr = bbox_thr
+
+ # Init model
+ self.model = init_pose_model(
+ self.model_config,
+ self.model_checkpoint,
+ device=self.device.lower())
+
+ # Store history for pose tracking
+ self.track_info = {
+ 'next_id': 0,
+ 'last_pose_preds': [],
+ 'last_time': None
+ }
+
+ # Register buffers
+ self.register_input_buffer(input_buffer, 'input', essential=True)
+ self.register_output_buffer(output_buffer)
+
+ def bypass(self, input_msgs):
+ return input_msgs['input']
+
+ def process(self, input_msgs: Dict[str, Message]) -> Message:
+
+ input_msg = input_msgs['input']
+ img = input_msg.get_image()
+ det_results = input_msg.get_detection_results()
+
+ if det_results is None:
+ raise ValueError(
+ 'No detection results are found in the frame message.'
+ f'{self.__class__.__name__} should be used after a '
+ 'detector node.')
+
+ full_det_preds = []
+ for det_result in det_results:
+ det_preds = det_result['preds']
+ if self.cls_ids:
+ # Filter detection results by class ID
+ det_preds = [
+ p for p in det_preds if p['cls_id'] in self.cls_ids
+ ]
+ elif self.cls_names:
+ # Filter detection results by class name
+ det_preds = [
+ p for p in det_preds if p['label'] in self.cls_names
+ ]
+ full_det_preds.extend(det_preds)
+
+ # Inference pose
+ pose_preds, _ = inference_top_down_pose_model(
+ self.model,
+ img,
+ full_det_preds,
+ bbox_thr=self.bbox_thr,
+ format='xyxy')
+
+ # Pose tracking
+ current_time = time.time()
+ if self.track_info['last_time'] is None:
+ fps = None
+ elif self.track_info['last_time'] >= current_time:
+ fps = None
+ else:
+ fps = 1.0 / (current_time - self.track_info['last_time'])
+
+ pose_preds, next_id = get_track_id(
+ pose_preds,
+ self.track_info['last_pose_preds'],
+ self.track_info['next_id'],
+ use_oks=False,
+ tracking_thr=0.3,
+ use_one_euro=True,
+ fps=fps)
+
+ self.track_info['next_id'] = next_id
+ self.track_info['last_pose_preds'] = pose_preds.copy()
+ self.track_info['last_time'] = current_time
+
+ pose_result = {
+ 'preds': pose_preds,
+ 'model_cfg': self.model.cfg.copy(),
+ }
+
+ input_msg.add_pose_result(pose_result, tag=self.name)
+
+ return input_msg
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e48d089dd18f8845125f50676cc175dbc2d24d
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/node.py
@@ -0,0 +1,372 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import time
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from queue import Empty
+from threading import Thread
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+from mmcv.utils.misc import is_method_overridden
+
+from mmpose.utils import StopWatch
+from ..utils import Message, VideoEndingMessage, limit_max_fps
+
+
+@dataclass
+class BufferInfo():
+ """Dataclass for buffer information."""
+ buffer_name: str
+ input_name: Optional[str] = None
+ essential: bool = False
+
+
+@dataclass
+class EventInfo():
+ """Dataclass for event handler information."""
+ event_name: str
+ is_keyboard: bool = False
+ handler_func: Optional[Callable] = None
+
+
+class Node(Thread, metaclass=ABCMeta):
+ """Base interface of functional module.
+
+ Parameters:
+ name (str, optional): The node name (also thread name).
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note:
+ 1. If enable_key is set, the bypass method need to be
+ overridden to define the node behavior when disabled
+ 2. Some hot-key has been use for particular use. For example:
+ 'q', 'Q' and 27 are used for quit
+ Default: None
+ max_fps (int): Maximum FPS of the node. This is to avoid the node
+ running unrestrictedly and causing large resource consuming.
+ Default: 30
+ input_check_interval (float): Minimum interval (in millisecond) between
+ checking if input is ready. Default: 0.001
+ enable (bool): Default enable/disable status. Default: True.
+ daemon (bool): Whether node is a daemon. Default: True.
+ """
+
+ def __init__(self,
+ name: Optional[str] = None,
+ enable_key: Optional[Union[str, int]] = None,
+ max_fps: int = 30,
+ input_check_interval: float = 0.01,
+ enable: bool = True,
+ daemon=False):
+ super().__init__(name=name, daemon=daemon)
+ self._runner = None
+ self._enabled = enable
+ self.enable_key = enable_key
+ self.max_fps = max_fps
+ self.input_check_interval = input_check_interval
+
+ # A partitioned buffer manager the runner's buffer manager that
+ # only accesses the buffers related to the node
+ self._buffer_manager = None
+
+ # Input/output buffers are a list of registered buffers' information
+ self._input_buffers = []
+ self._output_buffers = []
+
+ # Event manager is a copy of assigned runner's event manager
+ self._event_manager = None
+
+ # A list of registered event information
+ # See register_event() for more information
+ # Note that we recommend to handle events in nodes by registering
+ # handlers, but one can still access the raw event by _event_manager
+ self._registered_events = []
+
+ # A list of (listener_threads, event_info)
+ # See set_runner() for more information
+ self._event_listener_threads = []
+
+ # A timer to calculate node FPS
+ self._timer = StopWatch(window=10)
+
+ # Register enable toggle key
+ if self.enable_key:
+ # If the node allows toggling enable, it should override the
+ # `bypass` method to define the node behavior when disabled.
+ if not is_method_overridden('bypass', Node, self.__class__):
+ raise NotImplementedError(
+ f'The node {self.__class__} does not support toggling'
+ 'enable but got argument `enable_key`. To support toggling'
+ 'enable, please override the `bypass` method of the node.')
+
+ self.register_event(
+ event_name=self.enable_key,
+ is_keyboard=True,
+ handler_func=self._toggle_enable,
+ )
+
+ @property
+ def registered_buffers(self):
+ return self._input_buffers + self._output_buffers
+
+ @property
+ def registered_events(self):
+ return self._registered_events.copy()
+
+ def _toggle_enable(self):
+ self._enabled = not self._enabled
+
+ def register_input_buffer(self,
+ buffer_name: str,
+ input_name: str,
+ essential: bool = False):
+ """Register an input buffer, so that Node can automatically check if
+ data is ready, fetch data from the buffers and format the inputs to
+ feed into `process` method.
+
+ This method can be invoked multiple times to register multiple input
+ buffers.
+
+ The subclass of Node should invoke `register_input_buffer` in its
+ `__init__` method.
+
+ Args:
+ buffer_name (str): The name of the buffer
+ input_name (str): The name of the fetched message from the
+ corresponding buffer
+ essential (bool): An essential input means the node will wait
+ until the input is ready before processing. Otherwise, an
+ inessential input will not block the processing, instead
+ a None will be fetched if the buffer is not ready.
+ """
+ buffer_info = BufferInfo(buffer_name, input_name, essential)
+ self._input_buffers.append(buffer_info)
+
+ def register_output_buffer(self, buffer_name: Union[str, List[str]]):
+ """Register one or multiple output buffers, so that the Node can
+ automatically send the output of the `process` method to these buffers.
+
+ The subclass of Node should invoke `register_output_buffer` in its
+ `__init__` method.
+
+ Args:
+ buffer_name (str|list): The name(s) of the output buffer(s).
+ """
+
+ if not isinstance(buffer_name, list):
+ buffer_name = [buffer_name]
+
+ for name in buffer_name:
+ buffer_info = BufferInfo(name)
+ self._output_buffers.append(buffer_info)
+
+ def register_event(self,
+ event_name: str,
+ is_keyboard: bool = False,
+ handler_func: Optional[Callable] = None):
+ """Register an event. All events used in the node need to be registered
+ in __init__(). If a callable handler is given, a thread will be create
+ to listen and handle the event when the node starts.
+
+ Args:
+ Args:
+ event_name (str|int): The event name. If is_keyboard==True,
+ event_name should be a str (as char) or an int (as ascii)
+ is_keyboard (bool): Indicate whether it is an keyboard
+ event. If True, the argument event_name will be regarded as a
+ key indicator.
+ handler_func (callable, optional): The event handler function,
+ which should be a collable object with no arguments or
+ return values. Default: None.
+ """
+ event_info = EventInfo(event_name, is_keyboard, handler_func)
+ self._registered_events.append(event_info)
+
+ def set_runner(self, runner):
+ # Get partitioned buffer manager
+ buffer_names = [
+ buffer.buffer_name
+ for buffer in self._input_buffers + self._output_buffers
+ ]
+ self._buffer_manager = runner.buffer_manager.get_sub_manager(
+ buffer_names)
+
+ # Get event manager
+ self._event_manager = runner.event_manager
+
+ def _get_input_from_buffer(self) -> Tuple[bool, Optional[Dict]]:
+ """Get and pack input data if it's ready. The function returns a tuple
+ of a status flag and a packed data dictionary. If input_buffer is
+ ready, the status flag will be True, and the packed data is a dict
+ whose items are buffer names and corresponding messages (unready
+ additional buffers will give a `None`). Otherwise, the status flag is
+ False and the packed data is None.
+
+ Returns:
+ bool: status flag
+ dict[str, Message]: the packed inputs where the key is the buffer
+ name and the value is the Message got from the corresponding
+ buffer.
+ """
+ buffer_manager = self._buffer_manager
+
+ if buffer_manager is None:
+ raise ValueError(f'{self.name}: Runner not set!')
+
+ # Check that essential buffers are ready
+ for buffer_info in self._input_buffers:
+ if buffer_info.essential and buffer_manager.is_empty(
+ buffer_info.buffer_name):
+ return False, None
+
+ # Default input
+ result = {
+ buffer_info.input_name: None
+ for buffer_info in self._input_buffers
+ }
+
+ for buffer_info in self._input_buffers:
+ try:
+ result[buffer_info.input_name] = buffer_manager.get(
+ buffer_info.buffer_name, block=False)
+ except Empty:
+ if buffer_info.essential:
+ # Return unsuccessful flag if any
+ # essential input is unready
+ return False, None
+
+ return True, result
+
+ def _send_output_to_buffers(self, output_msg):
+ """Send output of the process method to registered output buffers.
+
+ Args:
+ output_msg (Message): output message
+ force (bool, optional): If True, block until the output message
+ has been put into all output buffers. Default: False
+ """
+ for buffer_info in self._output_buffers:
+ buffer_name = buffer_info.buffer_name
+ self._buffer_manager.put_force(buffer_name, output_msg)
+
+ @abstractmethod
+ def process(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ """The core method that implement the function of the node. This method
+ will be invoked when the node is enabled and the input data is ready.
+
+ All subclasses of Node should override this method.
+
+ Args:
+ input_msgs (dict): The input data collected from the buffers. For
+ each item, the key is the `input_name` of the registered input
+ buffer, while the value is a Message instance fetched from the
+ buffer (or None if the buffer is unessential and not ready).
+
+ Returns:
+ Message: The output message of the node. It will be send to all
+ registered output buffers.
+ """
+
+ def bypass(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ """The method that defines the node behavior when disabled. Note that
+ if the node has an `enable_key`, this method should be override.
+
+ The method input/output is same as it of `process` method.
+
+ Args:
+ input_msgs (dict): The input data collected from the buffers. For
+ each item, the key is the `input_name` of the registered input
+ buffer, while the value is a Message instance fetched from the
+ buffer (or None if the buffer is unessential and not ready).
+
+ Returns:
+ Message: The output message of the node. It will be send to all
+ registered output buffers.
+ """
+ raise NotImplementedError
+
+ def _get_node_info(self):
+ """Get route information of the node."""
+ info = {'fps': self._timer.report('_FPS_'), 'timestamp': time.time()}
+ return info
+
+ def on_exit(self):
+ """This method will be invoked on event `_exit_`.
+
+ Subclasses should override this method to specifying the exiting
+ behavior.
+ """
+
+ def run(self):
+ """Method representing the Node's activity.
+
+ This method override the standard run() method of Thread. Users should
+ not override this method in subclasses.
+ """
+
+ logging.info(f'Node {self.name} starts')
+
+ # Create event listener threads
+ for event_info in self._registered_events:
+
+ if event_info.handler_func is None:
+ continue
+
+ def event_listener():
+ while True:
+ with self._event_manager.wait_and_handle(
+ event_info.event_name, event_info.is_keyboard):
+ event_info.handler_func()
+
+ t_listener = Thread(target=event_listener, args=(), daemon=True)
+ t_listener.start()
+ self._event_listener_threads.append(t_listener)
+
+ # Loop
+ while True:
+ # Exit
+ if self._event_manager.is_set('_exit_'):
+ self.on_exit()
+ break
+
+ # Check if input is ready
+ input_status, input_msgs = self._get_input_from_buffer()
+
+ # Input is not ready
+ if not input_status:
+ time.sleep(self.input_check_interval)
+ continue
+
+ # If a VideoEndingMessage is received, broadcast the signal
+ # without invoking process() or bypass()
+ video_ending = False
+ for _, msg in input_msgs.items():
+ if isinstance(msg, VideoEndingMessage):
+ self._send_output_to_buffers(msg)
+ video_ending = True
+ break
+
+ if video_ending:
+ self.on_exit()
+ break
+
+ # Check if enabled
+ if not self._enabled:
+ # Override bypass method to define node behavior when disabled
+ output_msg = self.bypass(input_msgs)
+ else:
+ with self._timer.timeit():
+ with limit_max_fps(self.max_fps):
+ # Process
+ output_msg = self.process(input_msgs)
+
+ if output_msg:
+ # Update route information
+ node_info = self._get_node_info()
+ output_msg.update_route_info(node=self, info=node_info)
+
+ # Send output message
+ if output_msg is not None:
+ self._send_output_to_buffers(output_msg)
+
+ logging.info(f'{self.name}: process ending.')
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/valentinemagic_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/valentinemagic_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1c6a585065416b50f1c889272d7e869942354e
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/valentinemagic_node.py
@@ -0,0 +1,340 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+
+from ..utils import (FrameMessage, get_eye_keypoint_ids, get_hand_keypoint_ids,
+ get_mouth_keypoint_ids, load_image_from_disk_or_url)
+from .builder import NODES
+from .frame_drawing_node import FrameDrawingNode
+
+
+@dataclass
+class HeartInfo():
+ """Dataclass for heart information."""
+ heart_type: int
+ start_time: float
+ start_pos: Tuple[int, int]
+ end_pos: Tuple[int, int]
+
+
+@NODES.register_module()
+class ValentineMagicNode(FrameDrawingNode):
+
+ def __init__(self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ kpt_vis_thr: float = 0.3,
+ hand_heart_angle_thr: float = 90.0,
+ longest_duration: float = 2.0,
+ largest_ratio: float = 0.25,
+ hand_heart_img_path: Optional[str] = None,
+ flying_heart_img_path: Optional[str] = None,
+ hand_heart_dis_ratio_thr: float = 1.0,
+ flying_heart_dis_ratio_thr: float = 3.5,
+ num_persons: int = 2):
+
+ super().__init__(
+ name, frame_buffer, output_buffer, enable_key=enable_key)
+
+ if hand_heart_img_path is None:
+ hand_heart_img_path = 'https://user-images.githubusercontent.com/'\
+ '87690686/149731850-ea946766-a4e8-4efa-82f5'\
+ '-e2f0515db8ae.png'
+ if flying_heart_img_path is None:
+ flying_heart_img_path = 'https://user-images.githubusercontent.'\
+ 'com/87690686/153554948-937ce496-33dd-4'\
+ '9ab-9829-0433fd7c13c4.png'
+
+ self.hand_heart = load_image_from_disk_or_url(hand_heart_img_path)
+ self.flying_heart = load_image_from_disk_or_url(flying_heart_img_path)
+
+ self.kpt_vis_thr = kpt_vis_thr
+ self.hand_heart_angle_thr = hand_heart_angle_thr
+ self.hand_heart_dis_ratio_thr = hand_heart_dis_ratio_thr
+ self.flying_heart_dis_ratio_thr = flying_heart_dis_ratio_thr
+ self.longest_duration = longest_duration
+ self.largest_ratio = largest_ratio
+ self.num_persons = num_persons
+
+ # record the heart infos for each person
+ self.heart_infos = {}
+
+ def _cal_distance(self, p1: np.ndarray, p2: np.ndarray) -> np.float64:
+ """calculate the distance of points p1 and p2."""
+ return np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
+
+ def _cal_angle(self, p1: np.ndarray, p2: np.ndarray, p3: np.ndarray,
+ p4: np.ndarray) -> np.float64:
+ """calculate the angle of vectors v1(constructed by points p2 and p1)
+ and v2(constructed by points p4 and p3)"""
+ v1 = p2 - p1
+ v2 = p4 - p3
+
+ vector_prod = v1[0] * v2[0] + v1[1] * v2[1]
+ length_prod = np.sqrt(pow(v1[0], 2) + pow(v1[1], 2)) * np.sqrt(
+ pow(v2[0], 2) + pow(v2[1], 2))
+ cos = vector_prod * 1.0 / (length_prod * 1.0 + 1e-6)
+
+ return (np.arccos(cos) / np.pi) * 180
+
+ def _check_heart(self, pred: Dict[str,
+ np.ndarray], hand_indices: List[int],
+ mouth_index: int, eye_indices: List[int]) -> int:
+ """Check the type of Valentine Magic based on the pose results and
+ keypoint indices of hand, mouth. and eye.
+
+ Args:
+ pred(dict): The pose estimation results containing:
+ - "keypoints" (np.ndarray[K,3]): keypoint detection result
+ in [x, y, score]
+ hand_indices(list[int]): keypoint indices of hand
+ mouth_index(int): keypoint index of mouth
+ eye_indices(list[int]): keypoint indices of eyes
+
+ Returns:
+ int: a number representing the type of heart pose,
+ 0: None, 1: hand heart, 2: left hand blow kiss,
+ 3: right hand blow kiss
+ """
+ kpts = pred['keypoints']
+
+ left_eye_idx, right_eye_idx = eye_indices
+ left_eye_pos = kpts[left_eye_idx][:2]
+ right_eye_pos = kpts[right_eye_idx][:2]
+ eye_dis = self._cal_distance(left_eye_pos, right_eye_pos)
+
+ # these indices are corresoponding to the following keypoints:
+ # left_hand_root, left_pinky_finger1,
+ # left_pinky_finger3, left_pinky_finger4,
+ # right_hand_root, right_pinky_finger1
+ # right_pinky_finger3, right_pinky_finger4
+
+ both_hands_vis = True
+ for i in [0, 17, 19, 20, 21, 38, 40, 41]:
+ if kpts[hand_indices[i]][2] < self.kpt_vis_thr:
+ both_hands_vis = False
+
+ if both_hands_vis:
+ p1 = kpts[hand_indices[20]][:2]
+ p2 = kpts[hand_indices[19]][:2]
+ p3 = kpts[hand_indices[17]][:2]
+ p4 = kpts[hand_indices[0]][:2]
+ left_angle = self._cal_angle(p1, p2, p3, p4)
+
+ p1 = kpts[hand_indices[41]][:2]
+ p2 = kpts[hand_indices[40]][:2]
+ p3 = kpts[hand_indices[38]][:2]
+ p4 = kpts[hand_indices[21]][:2]
+ right_angle = self._cal_angle(p1, p2, p3, p4)
+
+ hand_dis = self._cal_distance(kpts[hand_indices[20]][:2],
+ kpts[hand_indices[41]][:2])
+
+ if (left_angle < self.hand_heart_angle_thr
+ and right_angle < self.hand_heart_angle_thr
+ and hand_dis / eye_dis < self.hand_heart_dis_ratio_thr):
+ return 1
+
+ # these indices are corresoponding to the following keypoints:
+ # left_middle_finger1, left_middle_finger4,
+ left_hand_vis = True
+ for i in [9, 12]:
+ if kpts[hand_indices[i]][2] < self.kpt_vis_thr:
+ left_hand_vis = False
+ break
+ # right_middle_finger1, right_middle_finger4
+
+ right_hand_vis = True
+ for i in [30, 33]:
+ if kpts[hand_indices[i]][2] < self.kpt_vis_thr:
+ right_hand_vis = False
+ break
+
+ mouth_vis = True
+ if kpts[mouth_index][2] < self.kpt_vis_thr:
+ mouth_vis = False
+
+ if (not left_hand_vis and not right_hand_vis) or not mouth_vis:
+ return 0
+
+ mouth_pos = kpts[mouth_index]
+
+ left_mid_hand_pos = (kpts[hand_indices[9]][:2] +
+ kpts[hand_indices[12]][:2]) / 2
+ lefthand_mouth_dis = self._cal_distance(left_mid_hand_pos, mouth_pos)
+
+ if lefthand_mouth_dis / eye_dis < self.flying_heart_dis_ratio_thr:
+ return 2
+
+ right_mid_hand_pos = (kpts[hand_indices[30]][:2] +
+ kpts[hand_indices[33]][:2]) / 2
+ righthand_mouth_dis = self._cal_distance(right_mid_hand_pos, mouth_pos)
+
+ if righthand_mouth_dis / eye_dis < self.flying_heart_dis_ratio_thr:
+ return 3
+
+ return 0
+
+ def _get_heart_route(self, heart_type: int, cur_pred: Dict[str,
+ np.ndarray],
+ tar_pred: Dict[str,
+ np.ndarray], hand_indices: List[int],
+ mouth_index: int) -> Tuple[int, int]:
+ """get the start and end position of the heart, based on two keypoint
+ results and keypoint indices of hand and mouth.
+
+ Args:
+ cur_pred(dict): The pose estimation results of current person,
+ containing: the following keys:
+ - "keypoints" (np.ndarray[K,3]): keypoint detection result
+ in [x, y, score]
+ tar_pred(dict): The pose estimation results of target person,
+ containing: the following keys:
+ - "keypoints" (np.ndarray[K,3]): keypoint detection result
+ in [x, y, score]
+ hand_indices(list[int]): keypoint indices of hand
+ mouth_index(int): keypoint index of mouth
+
+ Returns:
+ tuple(int): the start position of heart
+ tuple(int): the end position of heart
+ """
+ cur_kpts = cur_pred['keypoints']
+
+ assert heart_type in [1, 2,
+ 3], 'Can not determine the type of heart effect'
+
+ if heart_type == 1:
+ p1 = cur_kpts[hand_indices[20]][:2]
+ p2 = cur_kpts[hand_indices[41]][:2]
+ elif heart_type == 2:
+ p1 = cur_kpts[hand_indices[9]][:2]
+ p2 = cur_kpts[hand_indices[12]][:2]
+ elif heart_type == 3:
+ p1 = cur_kpts[hand_indices[30]][:2]
+ p2 = cur_kpts[hand_indices[33]][:2]
+
+ cur_x, cur_y = (p1 + p2) / 2
+ # the mid point of two fingers
+ start_pos = (int(cur_x), int(cur_y))
+
+ tar_kpts = tar_pred['keypoints']
+ end_pos = tar_kpts[mouth_index][:2]
+
+ return start_pos, end_pos
+
+ def _draw_heart(self, canvas: np.ndarray, heart_info: HeartInfo,
+ t_pass: float) -> np.ndarray:
+ """draw the heart according to heart info and time."""
+ start_x, start_y = heart_info.start_pos
+ end_x, end_y = heart_info.end_pos
+
+ scale = t_pass / self.longest_duration
+
+ max_h, max_w = canvas.shape[:2]
+ hm, wm = self.largest_ratio * max_h, self.largest_ratio * max_h
+ new_h, new_w = int(hm * scale), int(wm * scale)
+
+ x = int(start_x + scale * (end_x - start_x))
+ y = int(start_y + scale * (end_y - start_y))
+
+ y1 = max(0, y - int(new_h / 2))
+ y2 = min(max_h - 1, y + int(new_h / 2))
+
+ x1 = max(0, x - int(new_w / 2))
+ x2 = min(max_w - 1, x + int(new_w / 2))
+
+ target = canvas[y1:y2 + 1, x1:x2 + 1].copy()
+ new_h, new_w = target.shape[:2]
+
+ if new_h == 0 or new_w == 0:
+ return canvas
+
+ assert heart_info.heart_type in [
+ 1, 2, 3
+ ], 'Can not determine the type of heart effect'
+ if heart_info.heart_type == 1: # hand heart
+ patch = self.hand_heart.copy()
+ elif heart_info.heart_type >= 2: # hand blow kiss
+ patch = self.flying_heart.copy()
+ if heart_info.start_pos[0] > heart_info.end_pos[0]:
+ patch = patch[:, ::-1]
+
+ patch = cv2.resize(patch, (new_w, new_h))
+ mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask = (mask < 100)[..., None].astype(np.float32) * 0.8
+
+ canvas[y1:y2 + 1, x1:x2 + 1] = patch * mask + target * (1 - mask)
+
+ return canvas
+
+ def draw(self, frame_msg: FrameMessage) -> np.ndarray:
+ canvas = frame_msg.get_image()
+
+ pose_results = frame_msg.get_pose_results()
+ if not pose_results:
+ return canvas
+
+ for pose_result in pose_results:
+ model_cfg = pose_result['model_cfg']
+
+ preds = [pred.copy() for pred in pose_result['preds']]
+ # if number of persons in the image is less than 2,
+ # no heart effect will be triggered
+ if len(preds) < self.num_persons:
+ continue
+
+ # if number of persons in the image is more than 2,
+ # only use the first two pose results
+ preds = preds[:self.num_persons]
+ ids = [preds[i]['track_id'] for i in range(self.num_persons)]
+
+ for id in self.heart_infos.copy():
+ if id not in ids:
+ # if the id of a person not in previous heart_infos,
+ # delete the corresponding field
+ del self.heart_infos[id]
+
+ for i in range(self.num_persons):
+ id = preds[i]['track_id']
+
+ # if the predicted person in previous heart_infos,
+ # draw the heart
+ if id in self.heart_infos.copy():
+ t_pass = time.time() - self.heart_infos[id].start_time
+
+ # the time passed since last heart pose less than
+ # longest_duration, continue to draw the heart
+ if t_pass < self.longest_duration:
+ canvas = self._draw_heart(canvas, self.heart_infos[id],
+ t_pass)
+ # reset corresponding heart info
+ else:
+ del self.heart_infos[id]
+ else:
+ hand_indices = get_hand_keypoint_ids(model_cfg)
+ mouth_index = get_mouth_keypoint_ids(model_cfg)
+ eye_indices = get_eye_keypoint_ids(model_cfg)
+
+ # check the type of Valentine Magic based on pose results
+ # and keypoint indices of hand and mouth
+ heart_type = self._check_heart(preds[i], hand_indices,
+ mouth_index, eye_indices)
+ # trigger a Valentine Magic effect
+ if heart_type:
+ # get the route of heart
+ start_pos, end_pos = self._get_heart_route(
+ heart_type, preds[i],
+ preds[self.num_persons - 1 - i], hand_indices,
+ mouth_index)
+ start_time = time.time()
+ self.heart_infos[id] = HeartInfo(
+ heart_type, start_time, start_pos, end_pos)
+
+ return canvas
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/xdwendwen_node.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/xdwendwen_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a0914d3bf473f278023ed1569ae18d6d1b5fcf3
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/nodes/xdwendwen_node.py
@@ -0,0 +1,240 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+from dataclasses import dataclass
+from typing import List, Tuple, Union
+
+import cv2
+import numpy as np
+
+from mmpose.datasets.dataset_info import DatasetInfo
+from ..utils import load_image_from_disk_or_url
+from .builder import NODES
+from .frame_drawing_node import FrameDrawingNode
+
+
+@dataclass
+class DynamicInfo:
+ pos_curr: Tuple[int, int] = (0, 0)
+ pos_step: Tuple[int, int] = (0, 0)
+ step_curr: int = 0
+
+
+@NODES.register_module()
+class XDwenDwenNode(FrameDrawingNode):
+ """An effect drawing node that captures the face of a cat or dog and blend
+ it into a Bing-Dwen-Dwen (the mascot of 2022 Beijing Winter Olympics).
+
+ Parameters:
+ name (str, optional): The node name (also thread name).
+ frame_buffer (str): The name of the input buffer.
+ output_buffer (str | list): The name(s) of the output buffer(s).
+ mode_key (str | int): A hot key to switch the background image.
+ resource_file (str): The annotation file of resource images, which
+ should be in Labelbee format and contain both facial keypoint and
+ region annotations.
+ out_shape (tuple): The shape of output frame in (width, height).
+ """
+
+ dynamic_scale = 0.15
+ dynamic_max_step = 15
+
+ def __init__(
+ self,
+ name: str,
+ frame_buffer: str,
+ output_buffer: Union[str, List[str]],
+ mode_key: Union[str, int],
+ resource_file: str,
+ out_shape: Tuple[int, int] = (480, 480),
+ rigid_transform: bool = True,
+ ):
+ super().__init__(name, frame_buffer, output_buffer, enable=True)
+
+ self.mode_key = mode_key
+ self.mode_index = 0
+ self.out_shape = out_shape
+ self.rigid = rigid_transform
+
+ self.latest_pred = None
+
+ self.dynamic_info = DynamicInfo()
+
+ self.register_event(
+ self.mode_key, is_keyboard=True, handler_func=self.switch_mode)
+
+ self._init_resource(resource_file)
+
+ def _init_resource(self, resource_file):
+
+ # The resource_file is a JSON file that contains the facial
+ # keypoint and mask annotation information of the resource files.
+ # The annotations should follow the label-bee standard format.
+ # See https://github.com/open-mmlab/labelbee-client for details.
+ with open(resource_file) as f:
+ anns = json.load(f)
+ resource_infos = []
+
+ for ann in anns:
+ # Load image
+ img = load_image_from_disk_or_url(ann['url'])
+ # Load result
+ rst = json.loads(ann['result'])
+
+ # Check facial keypoint information
+ assert rst['step_1']['toolName'] == 'pointTool'
+ assert len(rst['step_1']['result']) == 3
+
+ keypoints = sorted(
+ rst['step_1']['result'], key=lambda x: x['order'])
+ keypoints = np.array([[pt['x'], pt['y']] for pt in keypoints])
+
+ # Check facial mask
+ assert rst['step_2']['toolName'] == 'polygonTool'
+ assert len(rst['step_2']['result']) == 1
+ assert len(rst['step_2']['result'][0]['pointList']) > 2
+
+ mask_pts = np.array(
+ [[pt['x'], pt['y']]
+ for pt in rst['step_2']['result'][0]['pointList']])
+
+ mul = 1.0 + self.dynamic_scale
+
+ w_scale = self.out_shape[0] / img.shape[1] * mul
+ h_scale = self.out_shape[1] / img.shape[0] * mul
+
+ img = cv2.resize(
+ img,
+ dsize=None,
+ fx=w_scale,
+ fy=h_scale,
+ interpolation=cv2.INTER_CUBIC)
+
+ keypoints *= [w_scale, h_scale]
+ mask_pts *= [w_scale, h_scale]
+
+ mask = cv2.fillPoly(
+ np.zeros(img.shape[:2], dtype=np.uint8),
+ [mask_pts.astype(np.int32)],
+ color=1)
+
+ res = {
+ 'img': img,
+ 'keypoints': keypoints,
+ 'mask': mask,
+ }
+ resource_infos.append(res)
+
+ self.resource_infos = resource_infos
+
+ self._reset_dynamic()
+
+ def switch_mode(self):
+ self.mode_index = (self.mode_index + 1) % len(self.resource_infos)
+
+ def _reset_dynamic(self):
+ x_tar = np.random.randint(int(self.out_shape[0] * self.dynamic_scale))
+ y_tar = np.random.randint(int(self.out_shape[1] * self.dynamic_scale))
+
+ x_step = (x_tar -
+ self.dynamic_info.pos_curr[0]) / self.dynamic_max_step
+ y_step = (y_tar -
+ self.dynamic_info.pos_curr[1]) / self.dynamic_max_step
+
+ self.dynamic_info.pos_step = (x_step, y_step)
+ self.dynamic_info.step_curr = 0
+
+ def draw(self, frame_msg):
+
+ full_pose_results = frame_msg.get_pose_results()
+
+ pred = None
+ if full_pose_results:
+ for pose_results in full_pose_results:
+ if not pose_results['preds']:
+ continue
+
+ pred = pose_results['preds'][0].copy()
+ pred['dataset'] = DatasetInfo(pose_results['model_cfg'].data.
+ test.dataset_info).dataset_name
+
+ self.latest_pred = pred
+ break
+
+ # Use the latest pose result if there is none available in
+ # the current frame.
+ if pred is None:
+ pred = self.latest_pred
+
+ # Get the background image and facial annotations
+ res = self.resource_infos[self.mode_index]
+ img = frame_msg.get_image()
+ canvas = res['img'].copy()
+ mask = res['mask']
+ kpts_tar = res['keypoints']
+
+ if pred is not None:
+ if pred['dataset'] == 'ap10k':
+ # left eye: 0, right eye: 1, nose: 2
+ kpts_src = pred['keypoints'][[0, 1, 2], :2]
+ elif pred['dataset'] == 'coco_wholebody':
+ # left eye: 1, right eye 2, nose: 0
+ kpts_src = pred['keypoints'][[1, 2, 0], :2]
+ else:
+ raise ValueError('Can not obtain face landmark information'
+ f'from dataset: {pred["type"]}')
+
+ trans_mat = self._get_transform(kpts_src, kpts_tar)
+
+ warp = cv2.warpAffine(img, trans_mat, dsize=canvas.shape[:2])
+ cv2.copyTo(warp, mask, canvas)
+
+ # Add random movement to the background
+ xc, yc = self.dynamic_info.pos_curr
+ xs, ys = self.dynamic_info.pos_step
+ w, h = self.out_shape
+
+ x = min(max(int(xc), 0), canvas.shape[1] - w + 1)
+ y = min(max(int(yc), 0), canvas.shape[0] - h + 1)
+
+ canvas = canvas[y:y + h, x:x + w]
+
+ self.dynamic_info.pos_curr = (xc + xs, yc + ys)
+ self.dynamic_info.step_curr += 1
+
+ if self.dynamic_info.step_curr == self.dynamic_max_step:
+ self._reset_dynamic()
+
+ return canvas
+
+ def _get_transform(self, kpts_src, kpts_tar):
+ if self.rigid:
+ # rigid transform
+ n = kpts_src.shape[0]
+ X = np.zeros((n * 2, 4), dtype=np.float32)
+ U = np.zeros((n * 2, 1), dtype=np.float32)
+ X[:n, :2] = kpts_src
+ X[:n, 2] = 1
+ X[n:, 0] = kpts_src[:, 1]
+ X[n:, 1] = -kpts_src[:, 0]
+ X[n:, 3] = 1
+
+ U[:n, 0] = kpts_tar[:, 0]
+ U[n:, 0] = kpts_tar[:, 1]
+
+ M = np.linalg.pinv(X).dot(U).flatten()
+
+ trans_mat = np.array([[M[0], M[1], M[2]], [-M[1], M[0], M[3]]],
+ dtype=np.float32)
+
+ else:
+ # normal affine transform
+ # adaptive horizontal flipping
+ if (np.linalg.norm(kpts_tar[0] - kpts_tar[2]) -
+ np.linalg.norm(kpts_tar[1] - kpts_tar[2])) * (
+ np.linalg.norm(kpts_src[0] - kpts_src[2]) -
+ np.linalg.norm(kpts_src[1] - kpts_src[2])) < 0:
+ kpts_src = kpts_src[[1, 0, 2], :]
+ trans_mat, _ = cv2.estimateAffine2D(
+ kpts_src.astype(np.float32), kpts_tar.astype(np.float32))
+
+ return trans_mat
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/__init__.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d906df0748cd6e5f87642ea6fdc9511e833e22ff
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .buffer import BufferManager
+from .event import EventManager
+from .message import FrameMessage, Message, VideoEndingMessage
+from .misc import (ImageCapture, copy_and_paste, expand_and_clamp,
+ get_cached_file_path, is_image_file, limit_max_fps,
+ load_image_from_disk_or_url, screen_matting)
+from .pose import (get_eye_keypoint_ids, get_face_keypoint_ids,
+ get_hand_keypoint_ids, get_mouth_keypoint_ids,
+ get_wrist_keypoint_ids)
+
+__all__ = [
+ 'BufferManager',
+ 'EventManager',
+ 'FrameMessage',
+ 'Message',
+ 'limit_max_fps',
+ 'VideoEndingMessage',
+ 'load_image_from_disk_or_url',
+ 'get_cached_file_path',
+ 'screen_matting',
+ 'expand_and_clamp',
+ 'copy_and_paste',
+ 'is_image_file',
+ 'ImageCapture',
+ 'get_eye_keypoint_ids',
+ 'get_face_keypoint_ids',
+ 'get_wrist_keypoint_ids',
+ 'get_mouth_keypoint_ids',
+ 'get_hand_keypoint_ids',
+]
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/buffer.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9fca4c392703bccb710a9659db21f56ea92e282
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/buffer.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import wraps
+from queue import Queue
+from typing import Dict, List, Optional
+
+from mmcv import is_seq_of
+
+__all__ = ['BufferManager']
+
+
+def check_buffer_registered(exist=True):
+
+ def wrapper(func):
+
+ @wraps(func)
+ def wrapped(manager, name, *args, **kwargs):
+ if exist:
+ # Assert buffer exist
+ if name not in manager:
+ raise ValueError(f'Fail to call {func.__name__}: '
+ f'buffer "{name}" is not registered.')
+ else:
+ # Assert buffer not exist
+ if name in manager:
+ raise ValueError(f'Fail to call {func.__name__}: '
+ f'buffer "{name}" is already registered.')
+ return func(manager, name, *args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+class Buffer(Queue):
+
+ def put_force(self, item):
+ """Force to put an item into the buffer.
+
+ If the buffer is already full, the earliest item in the buffer will be
+ remove to make room for the incoming item.
+ """
+ with self.mutex:
+ if self.maxsize > 0:
+ while self._qsize() >= self.maxsize:
+ _ = self._get()
+ self.unfinished_tasks -= 1
+
+ self._put(item)
+ self.unfinished_tasks += 1
+ self.not_empty.notify()
+
+
+class BufferManager():
+
+ def __init__(self,
+ buffer_type: type = Buffer,
+ buffers: Optional[Dict] = None):
+ self.buffer_type = buffer_type
+ if buffers is None:
+ self._buffers = {}
+ else:
+ if is_seq_of(list(buffers.values()), buffer_type):
+ self._buffers = buffers.copy()
+ else:
+ raise ValueError('The values of buffers should be instance '
+ f'of {buffer_type}')
+
+ def __contains__(self, name):
+ return name in self._buffers
+
+ @check_buffer_registered(False)
+ def register_buffer(self, name, maxsize=0):
+ self._buffers[name] = self.buffer_type(maxsize)
+
+ @check_buffer_registered()
+ def put(self, name, item, block=True, timeout=None):
+ self._buffers[name].put(item, block, timeout)
+
+ @check_buffer_registered()
+ def put_force(self, name, item):
+ self._buffers[name].put_force(item)
+
+ @check_buffer_registered()
+ def get(self, name, block=True, timeout=None):
+ return self._buffers[name].get(block, timeout)
+
+ @check_buffer_registered()
+ def is_empty(self, name):
+ return self._buffers[name].empty()
+
+ @check_buffer_registered()
+ def is_full(self, name):
+ return self._buffers[name].full()
+
+ def get_sub_manager(self, buffer_names: List[str]):
+ buffers = {name: self._buffers[name] for name in buffer_names}
+ return BufferManager(self.buffer_type, buffers)
+
+ def get_info(self):
+ buffer_info = {}
+ for name, buffer in self._buffers.items():
+ buffer_info[name] = {
+ 'size': buffer.size,
+ 'maxsize': buffer.maxsize
+ }
+ return buffer_info
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/event.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/event.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceab26f72b63d03bc574cda3a713fed67f20f0c0
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/event.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from contextlib import contextmanager
+from threading import Event
+from typing import Optional
+
+
+class EventManager():
+
+ def __init__(self):
+ self._events = defaultdict(Event)
+
+ def register_event(self,
+ event_name: str = None,
+ is_keyboard: bool = False):
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ self._events[event_name] = Event()
+
+ def set(self, event_name: str = None, is_keyboard: bool = False):
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ return self._events[event_name].set()
+
+ def wait(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False,
+ timeout: Optional[float] = None):
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ return self._events[event_name].wait(timeout)
+
+ def is_set(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False):
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ return self._events[event_name].is_set()
+
+ def clear(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False):
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ return self._events[event_name].clear()
+
+ @staticmethod
+ def _get_keyboard_event_name(key):
+ return f'_keyboard_{chr(key) if isinstance(key,int) else key}'
+
+ @contextmanager
+ def wait_and_handle(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False):
+ self.wait(event_name, is_keyboard)
+ try:
+ yield
+ finally:
+ self.clear(event_name, is_keyboard)
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/message.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/message.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b1529c5ece3970dfae189d910720786f32612d
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/message.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+import uuid
+import warnings
+from typing import Dict, List, Optional
+
+import numpy as np
+
+
+class Message():
+ """Message base class.
+
+ All message class should inherit this class. The basic use of a Message
+ instance is to carray a piece of text message (self.msg) and a dict that
+ stores structured data (self.data), e.g. frame image, model prediction,
+ et al.
+
+ A message may also hold route information, which is composed of
+ information of all nodes the message has passed through.
+
+ Parameters:
+ msg (str): The text message.
+ data (dict, optional): The structured data.
+ """
+
+ def __init__(self, msg: str = '', data: Optional[Dict] = None):
+ self.msg = msg
+ self.data = data if data else {}
+ self.route_info = []
+ self.timestamp = time.time()
+ self.id = uuid.uuid4()
+
+ def update_route_info(self,
+ node=None,
+ node_name: Optional[str] = None,
+ node_type: Optional[str] = None,
+ info: Optional[Dict] = None):
+ """Append new node information to the route information.
+
+ Args:
+ node (Node, optional): An instance of Node that provides basic
+ information like the node name and type. Default: None.
+ node_name (str, optional): The node name. If node is given,
+ node_name will be ignored. Default: None.
+ node_type (str, optional): The class name of the node. If node
+ is given, node_type will be ignored. Default: None.
+ info (dict, optional): The node information, which is usually
+ given by node.get_node_info(). Default: None.
+ """
+ if node is not None:
+ if node_name is not None or node_type is not None:
+ warnings.warn(
+ '`node_name` and `node_type` will be overridden if node'
+ 'is provided.')
+ node_name = node.name
+ node_type = node.__class__.__name__
+
+ node_info = {'node': node_name, 'node_type': node_type, 'info': info}
+ self.route_info.append(node_info)
+
+ def set_route_info(self, route_info: List):
+ """Directly set the entire route information.
+
+ Args:
+ route_info (list): route information to set to the message.
+ """
+ self.route_info = route_info
+
+ def merge_route_info(self, route_info: List):
+ """Merge the given route information into the original one of the
+ message. This is used for combining route information from multiple
+ messages. The node information in the route will be reordered according
+ to their timestamps.
+
+ Args:
+ route_info (list): route information to merge.
+ """
+ self.route_info += route_info
+ self.route_info.sort(key=lambda x: x.get('timestamp', np.inf))
+
+ def get_route_info(self) -> List:
+ return self.route_info.copy()
+
+
+class VideoEndingMessage(Message):
+ """A special message to indicate the input video is ending."""
+
+
+class FrameMessage(Message):
+ """The message to store information of a video frame.
+
+ A FrameMessage instance usually holds following data in self.data:
+ - image (array): The frame image
+ - detection_results (list): A list to hold detection results of
+ multiple detectors. Each element is a tuple (tag, result)
+ - pose_results (list): A list to hold pose estimation results of
+ multiple pose estimator. Each element is a tuple (tag, result)
+ """
+
+ def __init__(self, img):
+ super().__init__(data=dict(image=img))
+
+ def get_image(self):
+ """Get the frame image.
+
+ Returns:
+ array: The frame image.
+ """
+ return self.data.get('image', None)
+
+ def set_image(self, img):
+ """Set the frame image to the message."""
+ self.data['image'] = img
+
+ def add_detection_result(self, result, tag: str = None):
+ """Add the detection result from one model into the message's
+ detection_results.
+
+ Args:
+ tag (str, optional): Give a tag to the result, which can be used
+ to retrieve specific results.
+ """
+ if 'detection_results' not in self.data:
+ self.data['detection_results'] = []
+ self.data['detection_results'].append((tag, result))
+
+ def get_detection_results(self, tag: str = None):
+ """Get detection results of the message.
+
+ Args:
+ tag (str, optional): If given, only the results with the tag
+ will be retrieved. Otherwise all results will be retrieved.
+ Default: None.
+
+ Returns:
+ list[dict]: The retrieved detection results
+ """
+ if 'detection_results' not in self.data:
+ return None
+ if tag is None:
+ results = [res for _, res in self.data['detection_results']]
+ else:
+ results = [
+ res for _tag, res in self.data['detection_results']
+ if _tag == tag
+ ]
+ return results
+
+ def add_pose_result(self, result, tag=None):
+ """Add the pose estimation result from one model into the message's
+ pose_results.
+
+ Args:
+ tag (str, optional): Give a tag to the result, which can be used
+ to retrieve specific results.
+ """
+ if 'pose_results' not in self.data:
+ self.data['pose_results'] = []
+ self.data['pose_results'].append((tag, result))
+
+ def get_pose_results(self, tag=None):
+ """Get pose estimation results of the message.
+
+ Args:
+ tag (str, optional): If given, only the results with the tag
+ will be retrieved. Otherwise all results will be retrieved.
+ Default: None.
+
+ Returns:
+ list[dict]: The retrieved pose results
+ """
+ if 'pose_results' not in self.data:
+ return None
+ if tag is None:
+ results = [res for _, res in self.data['pose_results']]
+ else:
+ results = [
+ res for _tag, res in self.data['pose_results'] if _tag == tag
+ ]
+ return results
+
+ def get_full_results(self):
+ """Get all model predictions of the message.
+
+ See set_full_results() for inference.
+
+ Returns:
+ dict: All model predictions, including:
+ - detection_results
+ - pose_results
+ """
+ result_keys = ['detection_results', 'pose_results']
+ results = {k: self.data[k] for k in result_keys}
+ return results
+
+ def set_full_results(self, results):
+ """Set full model results directly.
+
+ Args:
+ results (dict): All model predictions including:
+ - detection_results (list): see also add_detection_results()
+ - pose_results (list): see also add_pose_results()
+ """
+ self.data.update(results)
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/misc.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64f4179db8a3618b38e3d6933992e9b3294af55
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/misc.py
@@ -0,0 +1,343 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import sys
+import time
+from contextlib import contextmanager
+from typing import Optional
+from urllib.parse import urlparse
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+from torch.hub import HASH_REGEX, download_url_to_file
+
+
+@contextmanager
+def limit_max_fps(fps: Optional[float]):
+ t_start = time.time()
+ try:
+ yield
+ finally:
+ t_end = time.time()
+ if fps is not None:
+ t_sleep = 1.0 / fps - t_end + t_start
+ if t_sleep > 0:
+ time.sleep(t_sleep)
+
+
+def _is_url(filename):
+ """Check if the file is a url link.
+
+ Args:
+ filename (str): the file name or url link.
+
+ Returns:
+ bool: is url or not.
+ """
+ prefixes = ['http://', 'https://']
+ for p in prefixes:
+ if filename.startswith(p):
+ return True
+ return False
+
+
+def load_image_from_disk_or_url(filename, readFlag=cv2.IMREAD_COLOR):
+ """Load an image file, from disk or url.
+
+ Args:
+ filename (str): file name on the disk or url link.
+ readFlag (int): readFlag for imdecode.
+
+ Returns:
+ np.ndarray: A loaded image
+ """
+ if _is_url(filename):
+ # download the image, convert it to a NumPy array, and then read
+ # it into OpenCV format
+ resp = urlopen(filename)
+ image = np.asarray(bytearray(resp.read()), dtype='uint8')
+ image = cv2.imdecode(image, readFlag)
+ return image
+ else:
+ image = cv2.imread(filename, readFlag)
+ return image
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == '':
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def get_cached_file_path(url,
+ save_dir=None,
+ progress=True,
+ check_hash=False,
+ file_name=None):
+ r"""Loads the Torch serialized object at the given URL.
+
+ If downloaded file is a zip file, it will be automatically decompressed
+
+ If the object is already present in `model_dir`, it's deserialized and
+ returned.
+ The default value of ``model_dir`` is ``/checkpoints`` where
+ ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+ Args:
+ url (str): URL of the object to download
+ save_dir (str, optional): directory in which to save the object
+ progress (bool, optional): whether or not to display a progress bar
+ to stderr. Default: True
+ check_hash(bool, optional): If True, the filename part of the URL
+ should follow the naming convention ``filename-.ext``
+ where ```` is the first eight or more digits of the
+ SHA256 hash of the contents of the file. The hash is used to
+ ensure unique names and to verify the contents of the file.
+ Default: False
+ file_name (str, optional): name for the downloaded file. Filename
+ from ``url`` will be used if not set. Default: None.
+ """
+ if save_dir is None:
+ save_dir = os.path.join('webcam_resources')
+
+ mkdir_or_exist(save_dir)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.join(save_dir, filename)
+ if not os.path.exists(cached_file):
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+ hash_prefix = None
+ if check_hash:
+ r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
+ hash_prefix = r.group(1) if r else None
+ download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+ return cached_file
+
+
+def screen_matting(img, color_low=None, color_high=None, color=None):
+ """Screen Matting.
+
+ Args:
+ img (np.ndarray): Image data.
+ color_low (tuple): Lower limit (b, g, r).
+ color_high (tuple): Higher limit (b, g, r).
+ color (str): Support colors include:
+
+ - 'green' or 'g'
+ - 'blue' or 'b'
+ - 'black' or 'k'
+ - 'white' or 'w'
+ """
+
+ if color_high is None or color_low is None:
+ if color is not None:
+ if color.lower() == 'g' or color.lower() == 'green':
+ color_low = (0, 200, 0)
+ color_high = (60, 255, 60)
+ elif color.lower() == 'b' or color.lower() == 'blue':
+ color_low = (230, 0, 0)
+ color_high = (255, 40, 40)
+ elif color.lower() == 'k' or color.lower() == 'black':
+ color_low = (0, 0, 0)
+ color_high = (40, 40, 40)
+ elif color.lower() == 'w' or color.lower() == 'white':
+ color_low = (230, 230, 230)
+ color_high = (255, 255, 255)
+ else:
+ NotImplementedError(f'Not supported color: {color}.')
+ else:
+ ValueError('color or color_high | color_low should be given.')
+
+ mask = cv2.inRange(img, np.array(color_low), np.array(color_high)) == 0
+
+ return mask.astype(np.uint8)
+
+
+def expand_and_clamp(box, im_shape, s=1.25):
+ """Expand the bbox and clip it to fit the image shape.
+
+ Args:
+ box (list): x1, y1, x2, y2
+ im_shape (ndarray): image shape (h, w, c)
+ s (float): expand ratio
+
+ Returns:
+ list: x1, y1, x2, y2
+ """
+
+ x1, y1, x2, y2 = box[:4]
+ w = x2 - x1
+ h = y2 - y1
+ deta_w = w * (s - 1) / 2
+ deta_h = h * (s - 1) / 2
+
+ x1, y1, x2, y2 = x1 - deta_w, y1 - deta_h, x2 + deta_w, y2 + deta_h
+
+ img_h, img_w = im_shape[:2]
+
+ x1 = min(max(0, int(x1)), img_w - 1)
+ y1 = min(max(0, int(y1)), img_h - 1)
+ x2 = min(max(0, int(x2)), img_w - 1)
+ y2 = min(max(0, int(y2)), img_h - 1)
+
+ return [x1, y1, x2, y2]
+
+
+def _find_connected_components(mask):
+ """Find connected components and sort with areas.
+
+ Args:
+ mask (ndarray): instance segmentation result.
+
+ Returns:
+ ndarray (N, 5): Each item contains (x, y, w, h, area).
+ """
+ num, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
+ stats = stats[stats[:, 4].argsort()]
+ return stats
+
+
+def _find_bbox(mask):
+ """Find the bounding box for the mask.
+
+ Args:
+ mask (ndarray): Mask.
+
+ Returns:
+ list(4, ): Returned box (x1, y1, x2, y2).
+ """
+ mask_shape = mask.shape
+ if len(mask_shape) == 3:
+ assert mask_shape[-1] == 1, 'the channel of the mask should be 1.'
+ elif len(mask_shape) == 2:
+ pass
+ else:
+ NotImplementedError()
+
+ h, w = mask_shape[:2]
+ mask_w = mask.sum(0)
+ mask_h = mask.sum(1)
+
+ left = 0
+ right = w - 1
+ up = 0
+ down = h - 1
+
+ for i in range(w):
+ if mask_w[i] > 0:
+ break
+ left += 1
+
+ for i in range(w - 1, left, -1):
+ if mask_w[i] > 0:
+ break
+ right -= 1
+
+ for i in range(h):
+ if mask_h[i] > 0:
+ break
+ up += 1
+
+ for i in range(h - 1, up, -1):
+ if mask_h[i] > 0:
+ break
+ down -= 1
+
+ return [left, up, right, down]
+
+
+def copy_and_paste(img,
+ background_img,
+ mask,
+ bbox=None,
+ effect_region=(0.2, 0.2, 0.8, 0.8),
+ min_size=(20, 20)):
+ """Copy the image region and paste to the background.
+
+ Args:
+ img (np.ndarray): Image data.
+ background_img (np.ndarray): Background image data.
+ mask (ndarray): instance segmentation result.
+ bbox (ndarray): instance bbox, (x1, y1, x2, y2).
+ effect_region (tuple(4, )): The region to apply mask, the coordinates
+ are normalized (x1, y1, x2, y2).
+ """
+ background_img = background_img.copy()
+ background_h, background_w = background_img.shape[:2]
+ region_h = (effect_region[3] - effect_region[1]) * background_h
+ region_w = (effect_region[2] - effect_region[0]) * background_w
+ region_aspect_ratio = region_w / region_h
+
+ if bbox is None:
+ bbox = _find_bbox(mask)
+ instance_w = bbox[2] - bbox[0]
+ instance_h = bbox[3] - bbox[1]
+
+ if instance_w > min_size[0] and instance_h > min_size[1]:
+ aspect_ratio = instance_w / instance_h
+ if region_aspect_ratio > aspect_ratio:
+ resize_rate = region_h / instance_h
+ else:
+ resize_rate = region_w / instance_w
+
+ mask_inst = mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
+ img_inst = img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
+ img_inst = cv2.resize(img_inst, (int(
+ resize_rate * instance_w), int(resize_rate * instance_h)))
+ mask_inst = cv2.resize(
+ mask_inst,
+ (int(resize_rate * instance_w), int(resize_rate * instance_h)),
+ interpolation=cv2.INTER_NEAREST)
+
+ mask_ids = list(np.where(mask_inst == 1))
+ mask_ids[1] += int(effect_region[0] * background_w)
+ mask_ids[0] += int(effect_region[1] * background_h)
+
+ background_img[tuple(mask_ids)] = img_inst[np.where(mask_inst == 1)]
+
+ return background_img
+
+
+def is_image_file(path):
+ if isinstance(path, str):
+ if path.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp')):
+ return True
+ return False
+
+
+class ImageCapture:
+ """A mock-up version of cv2.VideoCapture that always return a const image.
+
+ Args:
+ image (str | ndarray): The image or image path
+ """
+
+ def __init__(self, image):
+ if isinstance(image, str):
+ self.image = load_image_from_disk_or_url(image)
+ else:
+ self.image = image
+
+ def isOpened(self):
+ return (self.image is not None)
+
+ def read(self):
+ return True, self.image.copy()
+
+ def release(self):
+ pass
+
+ def get(self, propId):
+ if propId == cv2.CAP_PROP_FRAME_WIDTH:
+ return self.image.shape[1]
+ elif propId == cv2.CAP_PROP_FRAME_HEIGHT:
+ return self.image.shape[0]
+ elif propId == cv2.CAP_PROP_FPS:
+ return np.nan
+ else:
+ raise NotImplementedError()
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/pose.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..196b40ef53d78173742d4d6f953176cf76238308
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/utils/pose.py
@@ -0,0 +1,226 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+from mmcv import Config
+
+from mmpose.datasets.dataset_info import DatasetInfo
+
+
+def get_eye_keypoint_ids(model_cfg: Config) -> Tuple[int, int]:
+ """A helpfer function to get the keypoint indices of left and right eyes
+ from the model config.
+
+ Args:
+ model_cfg (Config): pose model config.
+
+ Returns:
+ int: left eye keypoint index.
+ int: right eye keypoint index.
+ """
+ left_eye_idx = None
+ right_eye_idx = None
+
+ # try obtaining eye point ids from dataset_info
+ try:
+ dataset_info = DatasetInfo(model_cfg.data.test.dataset_info)
+ left_eye_idx = dataset_info.keypoint_name2id.get('left_eye', None)
+ right_eye_idx = dataset_info.keypoint_name2id.get('right_eye', None)
+ except AttributeError:
+ left_eye_idx = None
+ right_eye_idx = None
+
+ if left_eye_idx is None or right_eye_idx is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = model_cfg.data.test.type
+ if dataset_name in {
+ 'TopDownCocoDataset', 'TopDownCocoWholeBodyDataset'
+ }:
+ left_eye_idx = 1
+ right_eye_idx = 2
+ elif dataset_name in {'AnimalPoseDataset', 'AnimalAP10KDataset'}:
+ left_eye_idx = 0
+ right_eye_idx = 1
+ else:
+ raise ValueError('Can not determine the eye keypoint id of '
+ f'{dataset_name}')
+
+ return left_eye_idx, right_eye_idx
+
+
+def get_face_keypoint_ids(model_cfg: Config) -> Tuple[int, int]:
+ """A helpfer function to get the keypoint indices of the face from the
+ model config.
+
+ Args:
+ model_cfg (Config): pose model config.
+
+ Returns:
+ list[int]: face keypoint index.
+ """
+ face_indices = None
+
+ # try obtaining nose point ids from dataset_info
+ try:
+ dataset_info = DatasetInfo(model_cfg.data.test.dataset_info)
+ for id in range(68):
+ face_indices.append(
+ dataset_info.keypoint_name2id.get(f'face_{id}', None))
+ except AttributeError:
+ face_indices = None
+
+ if face_indices is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = model_cfg.data.test.type
+ if dataset_name in {'TopDownCocoWholeBodyDataset'}:
+ face_indices = list(range(23, 91))
+ else:
+ raise ValueError('Can not determine the face id of '
+ f'{dataset_name}')
+
+ return face_indices
+
+
+def get_wrist_keypoint_ids(model_cfg: Config) -> Tuple[int, int]:
+ """A helpfer function to get the keypoint indices of left and right wrist
+ from the model config.
+
+ Args:
+ model_cfg (Config): pose model config.
+ Returns:
+ int: left wrist keypoint index.
+ int: right wrist keypoint index.
+ """
+
+ # try obtaining eye point ids from dataset_info
+ try:
+ dataset_info = DatasetInfo(model_cfg.data.test.dataset_info)
+ left_wrist_idx = dataset_info.keypoint_name2id.get('left_wrist', None)
+ right_wrist_idx = dataset_info.keypoint_name2id.get(
+ 'right_wrist', None)
+ except AttributeError:
+ left_wrist_idx = None
+ right_wrist_idx = None
+
+ if left_wrist_idx is None or right_wrist_idx is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = model_cfg.data.test.type
+ if dataset_name in {
+ 'TopDownCocoDataset', 'TopDownCocoWholeBodyDataset'
+ }:
+ left_wrist_idx = 9
+ right_wrist_idx = 10
+ elif dataset_name == 'AnimalPoseDataset':
+ left_wrist_idx = 16
+ right_wrist_idx = 17
+ elif dataset_name == 'AnimalAP10KDataset':
+ left_wrist_idx = 7
+ right_wrist_idx = 10
+ else:
+ raise ValueError('Can not determine the eye keypoint id of '
+ f'{dataset_name}')
+
+ return left_wrist_idx, right_wrist_idx
+
+
+def get_mouth_keypoint_ids(model_cfg: Config) -> Tuple[int, int]:
+ """A helpfer function to get the keypoint indices of the left and right
+ part of mouth from the model config.
+
+ Args:
+ model_cfg (Config): pose model config.
+ Returns:
+ int: left-part mouth keypoint index.
+ int: right-part mouth keypoint index.
+ """
+ # try obtaining mouth point ids from dataset_info
+ try:
+ dataset_info = DatasetInfo(model_cfg.data.test.dataset_info)
+ mouth_index = dataset_info.keypoint_name2id.get('face-62', None)
+ except AttributeError:
+ mouth_index = None
+
+ if mouth_index is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = model_cfg.data.test.type
+ if dataset_name == 'TopDownCocoWholeBodyDataset':
+ mouth_index = 85
+ else:
+ raise ValueError('Can not determine the eye keypoint id of '
+ f'{dataset_name}')
+
+ return mouth_index
+
+
+def get_hand_keypoint_ids(model_cfg: Config) -> List[int]:
+ """A helpfer function to get the keypoint indices of left and right hand
+ from the model config.
+
+ Args:
+ model_cfg (Config): pose model config.
+ Returns:
+ list[int]: hand keypoint indices.
+ """
+ # try obtaining hand keypoint ids from dataset_info
+ try:
+ hand_indices = []
+ dataset_info = DatasetInfo(model_cfg.data.test.dataset_info)
+
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get('left_hand_root', None))
+
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'left_thumb{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'left_forefinger{id}',
+ None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'left_middle_finger{id}',
+ None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'left_ring_finger{id}',
+ None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'left_pinky_finger{id}',
+ None))
+
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get('right_hand_root', None))
+
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'right_thumb{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'right_forefinger{id}',
+ None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'right_middle_finger{id}',
+ None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'right_ring_finger{id}',
+ None))
+ for id in range(1, 5):
+ hand_indices.append(
+ dataset_info.keypoint_name2id.get(f'right_pinky_finger{id}',
+ None))
+
+ except AttributeError:
+ hand_indices = None
+
+ if hand_indices is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = model_cfg.data.test.type
+ if dataset_name in {'TopDownCocoWholeBodyDataset'}:
+ hand_indices = list(range(91, 133))
+ else:
+ raise ValueError('Can not determine the hand id of '
+ f'{dataset_name}')
+
+ return hand_indices
diff --git a/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/webcam_runner.py b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/webcam_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..7843b392cfd367d778109794a345f1c361395407
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/third-party/ViTPose/tools/webcam/webcam_apis/webcam_runner.py
@@ -0,0 +1,272 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import sys
+import time
+import warnings
+from contextlib import nullcontext
+from threading import Thread
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+
+from .nodes import NODES
+from .utils import (BufferManager, EventManager, FrameMessage, ImageCapture,
+ VideoEndingMessage, is_image_file, limit_max_fps)
+
+DEFAULT_FRAME_BUFFER_SIZE = 1
+DEFAULT_INPUT_BUFFER_SIZE = 1
+DEFAULT_DISPLAY_BUFFER_SIZE = 0
+DEFAULT_USER_BUFFER_SIZE = 1
+
+
+class WebcamRunner():
+ """An interface for building webcam application from config.
+
+ Parameters:
+ name (str): Runner name.
+ camera_id (int | str): The camera ID (usually the ID of the default
+ camera is 0). Alternatively a file path or a URL can be given
+ to load from a video or image file.
+ camera_frame_shape (tuple, optional): Set the frame shape of the
+ camera in (width, height). If not given, the default frame shape
+ will be used. This argument is only valid when using a camera
+ as the input source. Default: None
+ camera_fps (int): Video reading maximum FPS. Default: 30
+ buffer_sizes (dict, optional): A dict to specify buffer sizes. The
+ key is the buffer name and the value is the buffer size.
+ Default: None
+ nodes (list): Node configs.
+ """
+
+ def __init__(self,
+ name: str = 'Default Webcam Runner',
+ camera_id: Union[int, str] = 0,
+ camera_fps: int = 30,
+ camera_frame_shape: Optional[Tuple[int, int]] = None,
+ synchronous: bool = False,
+ buffer_sizes: Optional[Dict[str, int]] = None,
+ nodes: Optional[List[Dict]] = None):
+
+ # Basic parameters
+ self.name = name
+ self.camera_id = camera_id
+ self.camera_fps = camera_fps
+ self.camera_frame_shape = camera_frame_shape
+ self.synchronous = synchronous
+
+ # self.buffer_manager manages data flow between runner and nodes
+ self.buffer_manager = BufferManager()
+ # self.event_manager manages event-based asynchronous communication
+ self.event_manager = EventManager()
+ # self.node_list holds all node instance
+ self.node_list = []
+ # self.vcap is used to read camera frames. It will be built when the
+ # runner starts running
+ self.vcap = None
+
+ # Register runner events
+ self.event_manager.register_event('_exit_', is_keyboard=False)
+ if self.synchronous:
+ self.event_manager.register_event('_idle_', is_keyboard=False)
+
+ # Register nodes
+ if not nodes:
+ raise ValueError('No node is registered to the runner.')
+
+ # Register default buffers
+ if buffer_sizes is None:
+ buffer_sizes = {}
+ # _frame_ buffer
+ frame_buffer_size = buffer_sizes.get('_frame_',
+ DEFAULT_FRAME_BUFFER_SIZE)
+ self.buffer_manager.register_buffer('_frame_', frame_buffer_size)
+ # _input_ buffer
+ input_buffer_size = buffer_sizes.get('_input_',
+ DEFAULT_INPUT_BUFFER_SIZE)
+ self.buffer_manager.register_buffer('_input_', input_buffer_size)
+ # _display_ buffer
+ display_buffer_size = buffer_sizes.get('_display_',
+ DEFAULT_DISPLAY_BUFFER_SIZE)
+ self.buffer_manager.register_buffer('_display_', display_buffer_size)
+
+ # Build all nodes:
+ for node_cfg in nodes:
+ logging.info(f'Create node: {node_cfg.name}({node_cfg.type})')
+ node = NODES.build(node_cfg)
+
+ # Register node
+ self.node_list.append(node)
+
+ # Register buffers
+ for buffer_info in node.registered_buffers:
+ buffer_name = buffer_info.buffer_name
+ if buffer_name in self.buffer_manager:
+ continue
+ buffer_size = buffer_sizes.get(buffer_name,
+ DEFAULT_USER_BUFFER_SIZE)
+ self.buffer_manager.register_buffer(buffer_name, buffer_size)
+ logging.info(
+ f'Register user buffer: {buffer_name}({buffer_size})')
+
+ # Register events
+ for event_info in node.registered_events:
+ self.event_manager.register_event(
+ event_name=event_info.event_name,
+ is_keyboard=event_info.is_keyboard)
+ logging.info(f'Register event: {event_info.event_name}')
+
+ # Set runner for nodes
+ # This step is performed after node building when the runner has
+ # create full buffer/event managers and can
+ for node in self.node_list:
+ logging.info(f'Set runner for node: {node.name})')
+ node.set_runner(self)
+
+ def _read_camera(self):
+ """Continually read video frames and put them into buffers."""
+
+ camera_id = self.camera_id
+ fps = self.camera_fps
+
+ # Build video capture
+ if is_image_file(camera_id):
+ self.vcap = ImageCapture(camera_id)
+ else:
+ self.vcap = cv2.VideoCapture(camera_id)
+ if self.camera_frame_shape is not None:
+ width, height = self.camera_frame_shape
+ self.vcap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+ self.vcap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+
+ if not self.vcap.isOpened():
+ warnings.warn(f'Cannot open camera (ID={camera_id})')
+ sys.exit()
+
+ # Read video frames in a loop
+ first_frame = True
+ while not self.event_manager.is_set('_exit_'):
+ if self.synchronous:
+ if first_frame:
+ cm = nullcontext()
+ else:
+ # Read a new frame until the last frame has been processed
+ cm = self.event_manager.wait_and_handle('_idle_')
+ else:
+ # Read frames with a maximum FPS
+ cm = limit_max_fps(fps)
+
+ first_frame = False
+
+ with cm:
+ # Read a frame
+ ret_val, frame = self.vcap.read()
+ if ret_val:
+ # Put frame message (for display) into buffer `_frame_`
+ frame_msg = FrameMessage(frame)
+ self.buffer_manager.put('_frame_', frame_msg)
+
+ # Put input message (for model inference or other use)
+ # into buffer `_input_`
+ input_msg = FrameMessage(frame.copy())
+ input_msg.update_route_info(
+ node_name='Camera Info',
+ node_type='dummy',
+ info=self._get_camera_info())
+ self.buffer_manager.put_force('_input_', input_msg)
+
+ else:
+ # Put a video ending signal
+ self.buffer_manager.put('_frame_', VideoEndingMessage())
+
+ self.vcap.release()
+
+ def _display(self):
+ """Continually obtain and display output frames."""
+
+ output_msg = None
+
+ while not self.event_manager.is_set('_exit_'):
+ while self.buffer_manager.is_empty('_display_'):
+ time.sleep(0.001)
+
+ # Set _idle_ to allow reading next frame
+ if self.synchronous:
+ self.event_manager.set('_idle_')
+
+ # acquire output from buffer
+ output_msg = self.buffer_manager.get('_display_')
+
+ # None indicates input stream ends
+ if isinstance(output_msg, VideoEndingMessage):
+ self.event_manager.set('_exit_')
+ break
+
+ img = output_msg.get_image()
+
+ # show in a window
+ cv2.imshow(self.name, img)
+
+ # handle keyboard input
+ key = cv2.waitKey(1)
+ if key != -1:
+ self._on_keyboard_input(key)
+
+ cv2.destroyAllWindows()
+
+ def _on_keyboard_input(self, key):
+ """Handle the keyboard input."""
+
+ if key in (27, ord('q'), ord('Q')):
+ logging.info(f'Exit event captured: {key}')
+ self.event_manager.set('_exit_')
+ else:
+ logging.info(f'Keyboard event captured: {key}')
+ self.event_manager.set(key, is_keyboard=True)
+
+ def _get_camera_info(self):
+ """Return the camera information in a dict."""
+
+ frame_width = self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
+ frame_height = self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)
+ frame_rate = self.vcap.get(cv2.CAP_PROP_FPS)
+
+ cam_info = {
+ 'Camera ID': self.camera_id,
+ 'Source resolution': f'{frame_width}x{frame_height}',
+ 'Source FPS': frame_rate,
+ }
+
+ return cam_info
+
+ def run(self):
+ """Program entry.
+
+ This method starts all nodes as well as video I/O in separate threads.
+ """
+
+ try:
+ # Start node threads
+ non_daemon_nodes = []
+ for node in self.node_list:
+ node.start()
+ if not node.daemon:
+ non_daemon_nodes.append(node)
+
+ # Create a thread to read video frames
+ t_read = Thread(target=self._read_camera, args=())
+ t_read.start()
+
+ # Run display in the main thread
+ self._display()
+ logging.info('Display shut down')
+
+ # joint non-daemon nodes and runner threads
+ logging.info('Camera reading about to join')
+ t_read.join()
+
+ for node in non_daemon_nodes:
+ logging.info(f'Node {node.name} about to join')
+ node.join()
+
+ except KeyboardInterrupt:
+ pass
diff --git a/phantom/submodules/phantom-hamer/train.py b/phantom/submodules/phantom-hamer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..329e40bb19f3fed4ba42fd9fff1abaefa22ff287
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/train.py
@@ -0,0 +1,113 @@
+from typing import Optional, Tuple
+import pyrootutils
+
+root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".git", "pyproject.toml"],
+ pythonpath=True,
+ dotenv=True,
+)
+
+import os
+from pathlib import Path
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.plugins.environments import SLURMEnvironment
+#from pytorch_lightning.trainingtype import DDPPlugin
+
+from yacs.config import CfgNode
+from hamer.configs import dataset_config
+from hamer.datasets import HAMERDataModule
+from hamer.models.hamer import HAMER
+from hamer.utils.pylogger import get_pylogger
+from hamer.utils.misc import task_wrapper, log_hyperparameters
+
+# HACK reset the signal handling so the lightning is free to set it
+# Based on https://github.com/facebookincubator/submitit/issues/1709#issuecomment-1246758283
+import signal
+signal.signal(signal.SIGUSR1, signal.SIG_DFL)
+
+log = get_pylogger(__name__)
+
+
+@pl.utilities.rank_zero.rank_zero_only
+def save_configs(model_cfg: CfgNode, dataset_cfg: CfgNode, rootdir: str):
+ """Save config files to rootdir."""
+ Path(rootdir).mkdir(parents=True, exist_ok=True)
+ OmegaConf.save(config=model_cfg, f=os.path.join(rootdir, 'model_config.yaml'))
+ with open(os.path.join(rootdir, 'dataset_config.yaml'), 'w') as f:
+ f.write(dataset_cfg.dump())
+
+@task_wrapper
+def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+ # Load dataset config
+ dataset_cfg = dataset_config()
+
+ # Save configs
+ save_configs(cfg, dataset_cfg, cfg.paths.output_dir)
+
+ # Setup training and validation datasets
+ datamodule = HAMERDataModule(cfg, dataset_cfg)
+
+ # Setup model
+ model = HAMER(cfg)
+
+ # Setup Tensorboard logger
+ logger = TensorBoardLogger(os.path.join(cfg.paths.output_dir, 'tensorboard'), name='', version='', default_hp_metric=False)
+ loggers = [logger]
+
+ # Setup checkpoint saving
+ checkpoint_callback = pl.callbacks.ModelCheckpoint(
+ dirpath=os.path.join(cfg.paths.output_dir, 'checkpoints'),
+ every_n_train_steps=cfg.GENERAL.CHECKPOINT_STEPS,
+ save_last=True,
+ save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K,
+ )
+ rich_callback = pl.callbacks.RichProgressBar()
+ lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
+ callbacks = [
+ checkpoint_callback,
+ lr_monitor,
+ # rich_callback
+ ]
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.trainer,
+ callbacks=callbacks,
+ logger=loggers,
+ #plugins=(SLURMEnvironment(requeue_signal=signal.SIGUSR2) if (cfg.get('launcher',None) is not None) else DDPPlugin(find_unused_parameters=False)), # Submitit uses SIGUSR2
+ plugins=(SLURMEnvironment(requeue_signal=signal.SIGUSR2) if (cfg.get('launcher',None) is not None) else None), # Submitit uses SIGUSR2
+ )
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ log_hyperparameters(object_dict)
+
+ # Train the model
+ trainer.fit(model, datamodule=datamodule, ckpt_path='last')
+ log.info("Fitting done")
+
+
+@hydra.main(version_base="1.2", config_path=str(root/"hamer/configs_hydra"), config_name="train.yaml")
+def main(cfg: DictConfig) -> Optional[float]:
+ # train the model
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/phantom/submodules/phantom-hamer/vitpose_model.py b/phantom/submodules/phantom-hamer/vitpose_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9206f38420220234baca89f2a3438b2ca9ce51d
--- /dev/null
+++ b/phantom/submodules/phantom-hamer/vitpose_model.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mmpose.apis import inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result
+
+os.environ["PYOPENGL_PLATFORM"] = "egl"
+
+# project root directory
+ROOT_DIR = "./"
+VIT_DIR = os.path.join(ROOT_DIR, "third-party/ViTPose")
+
+class ViTPoseModel(object):
+ def __init__(self, device: str | torch.device, root_dir: str = ROOT_DIR, vit_dir: str = VIT_DIR):
+ self.MODEL_DICT = {
+ 'ViTPose+-G (multi-task train, COCO)': {
+ 'config': f'{vit_dir}/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/ViTPose_huge_wholebody_256x192.py',
+ 'model': f'{root_dir}/_DATA/vitpose_ckpts/vitpose+_huge/wholebody.pth',
+ },
+ }
+ self.device = torch.device(device)
+ self.model_name = 'ViTPose+-G (multi-task train, COCO)'
+ self.model = self._load_model(self.model_name)
+
+ def _load_all_models_once(self) -> None:
+ for name in self.MODEL_DICT:
+ self._load_model(name)
+
+ def _load_model(self, name: str) -> nn.Module:
+ dic = self.MODEL_DICT[name]
+ ckpt_path = dic['model']
+ model = init_pose_model(dic['config'], ckpt_path, device=self.device)
+ return model
+
+ def set_model(self, name: str) -> None:
+ if name == self.model_name:
+ return
+ self.model_name = name
+ self.model = self._load_model(name)
+
+ def predict_pose_and_visualize(
+ self,
+ image: np.ndarray,
+ det_results: list[np.ndarray],
+ box_score_threshold: float,
+ kpt_score_threshold: float,
+ vis_dot_radius: int,
+ vis_line_thickness: int,
+ ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
+ out = self.predict_pose(image, det_results, box_score_threshold)
+ vis = self.visualize_pose_results(image, out, kpt_score_threshold,
+ vis_dot_radius, vis_line_thickness)
+ return out, vis
+
+ def predict_pose(
+ self,
+ image: np.ndarray,
+ det_results: list[np.ndarray],
+ box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
+ image = image[:, :, ::-1] # RGB -> BGR
+ person_results = process_mmdet_results(det_results, 1)
+ out, _ = inference_top_down_pose_model(self.model,
+ image,
+ person_results=person_results,
+ bbox_thr=box_score_threshold,
+ format='xyxy')
+ return out
+
+ def visualize_pose_results(self,
+ image: np.ndarray,
+ pose_results: list[np.ndarray],
+ kpt_score_threshold: float = 0.3,
+ vis_dot_radius: int = 4,
+ vis_line_thickness: int = 1) -> np.ndarray:
+ image = image[:, :, ::-1] # RGB -> BGR
+ vis = vis_pose_result(self.model,
+ image,
+ pose_results,
+ kpt_score_thr=kpt_score_threshold,
+ radius=vis_dot_radius,
+ thickness=vis_line_thickness)
+ return vis[:, :, ::-1] # BGR -> RGB
diff --git a/phantom/submodules/phantom-robomimic/.gitignore b/phantom/submodules/phantom-robomimic/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..19002b4a83137bcec04430bbb303436f4ea15b6a
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/.gitignore
@@ -0,0 +1,126 @@
+# pip distribution folder
+dist/
+
+# datasets folder at top-level (leading slash)
+/datasets
+/experiment_results
+
+# local test dataset that is lazily downloaded by example scripts
+tests/assets/test.hdf5
+tests/assets/test_v141.hdf5
+
+# pycharm configs
+.idea/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+
+.DS_Store
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+*.mp4
+*.pth
+
+# private macros
+macros_private.py
diff --git a/phantom/submodules/phantom-robomimic/LICENSE b/phantom/submodules/phantom-robomimic/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..934eaa87bb98d79ced50c8f27849625dc97b934d
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Stanford Vision and Learning Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/phantom/submodules/phantom-robomimic/MANIFEST.in b/phantom/submodules/phantom-robomimic/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..03d46cbc6f2c4d1ebfbba8c8049fa04342f9defd
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/MANIFEST.in
@@ -0,0 +1,9 @@
+include robomimic/exps/templates/*.json
+include robomimic/scripts/*.py
+include robomimic/scripts/*.sh
+include robomimic/scripts/conversion/*.py
+include robomimic/scripts/conversion/*.sh
+recursive-include examples/ *.py
+recursive-include tests/ *.py
+recursive-include tests/ *.sh
+recursive-include tests/assets/ *
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/README.md b/phantom/submodules/phantom-robomimic/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bdc3556ff9439556f124d83502669aba42d621b0
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/README.md
@@ -0,0 +1,90 @@
+# robomimic
+
+
+
+
+
+
+
+
+
+
+
+
+[**[Homepage]**](https://robomimic.github.io/) [**[Documentation]**](https://robomimic.github.io/docs/introduction/overview.html) [**[Study Paper]**](https://arxiv.org/abs/2108.03298) [**[Study Website]**](https://robomimic.github.io/study/) [**[ARISE Initiative]**](https://github.com/ARISE-Initiative)
+
+-------
+## Latest Updates
+- [10/11/2023] **v0.3.1**: support for extracting, training on, and visualizing depth observations for robosuite datasets
+- [07/03/2023] **v0.3.0**: BC-Transformer and IQL :brain:, support for DeepMind MuJoCo bindings :robot:, pre-trained image reps :eye:, wandb logging :chart_with_upwards_trend:, and more
+- [05/23/2022] **v0.2.1**: Updated website and documentation to feature more tutorials :notebook_with_decorative_cover:
+- [12/16/2021] **v0.2.0**: Modular observation modalities and encoders :wrench:, support for [MOMART](https://sites.google.com/view/il-for-mm/home) datasets :open_file_folder: [[release notes]](https://github.com/ARISE-Initiative/robomimic/releases/tag/v0.2.0) [[documentation]](https://robomimic.github.io/docs/v0.2/introduction/overview.html)
+- [08/09/2021] **v0.1.0**: Initial code and paper release
+
+-------
+
+## Colab quickstart
+Get started with a quick colab notebook demo of robomimic without installing anything locally.
+
+[](https://colab.research.google.com/drive/1b62r_km9pP40fKF0cBdpdTO2P_2eIbC6?usp=sharing)
+
+
+-------
+
+**robomimic** is a framework for robot learning from demonstration.
+It offers a broad set of demonstration datasets collected on robot manipulation domains and offline learning algorithms to learn from these datasets.
+**robomimic** aims to make robot learning broadly *accessible* and *reproducible*, allowing researchers and practitioners to benchmark tasks and algorithms fairly and to develop the next generation of robot learning algorithms.
+
+## Core Features
+
+
+
+
+
+
+
+
+## Reproducing benchmarks
+
+The robomimic framework also makes reproducing the results from different benchmarks and datasets easy. See the [datasets page](https://robomimic.github.io/docs/datasets/overview.html) for more information on downloading datasets and reproducing experiments.
+
+## Troubleshooting
+
+Please see the [troubleshooting](https://robomimic.github.io/docs/miscellaneous/troubleshooting.html) section for common fixes, or [submit an issue](https://github.com/ARISE-Initiative/robomimic/issues) on our github page.
+
+## Contributing to robomimic
+This project is part of the broader [Advancing Robot Intelligence through Simulated Environments (ARISE) Initiative](https://github.com/ARISE-Initiative), with the aim of lowering the barriers of entry for cutting-edge research at the intersection of AI and Robotics.
+The project originally began development in late 2018 by researchers in the [Stanford Vision and Learning Lab](http://svl.stanford.edu/) (SVL).
+Now it is actively maintained and used for robotics research projects across multiple labs.
+We welcome community contributions to this project.
+For details please check our [contributing guidelines](https://robomimic.github.io/docs/miscellaneous/contributing.html).
+
+## Citation
+
+Please cite [this paper](https://arxiv.org/abs/2108.03298) if you use this framework in your work:
+
+```bibtex
+@inproceedings{robomimic2021,
+ title={What Matters in Learning from Offline Human Demonstrations for Robot Manipulation},
+ author={Ajay Mandlekar and Danfei Xu and Josiah Wong and Soroush Nasiriany and Chen Wang and Rohun Kulkarni and Li Fei-Fei and Silvio Savarese and Yuke Zhu and Roberto Mart\'{i}n-Mart\'{i}n},
+ booktitle={Conference on Robot Learning (CoRL)},
+ year={2021}
+}
+```
diff --git a/phantom/submodules/phantom-robomimic/requirements-docs.txt b/phantom/submodules/phantom-robomimic/requirements-docs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4b0538e24a3a648ba885ddf9d42b39c26f48087a
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/requirements-docs.txt
@@ -0,0 +1,8 @@
+# requirements for building sphinx docs
+pygments==2.4.1
+sphinx
+sphinx_rtd_theme
+sphinx_markdown_tables
+sphinx_book_theme
+recommonmark
+nbsphinx
diff --git a/phantom/submodules/phantom-robomimic/requirements.txt b/phantom/submodules/phantom-robomimic/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ac8ea89aa900a793eff45c57b39f3137e9edb872
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/requirements.txt
@@ -0,0 +1,14 @@
+numpy>=1.13.3
+h5py
+psutil
+tqdm
+termcolor
+tensorboard
+tensorboardX
+imageio
+imageio-ffmpeg
+matplotlib
+egl_probe>=1.0.1
+torch
+torchvision
+diffusers==0.11.1
diff --git a/phantom/submodules/phantom-robomimic/robomimic/__init__.py b/phantom/submodules/phantom-robomimic/robomimic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1930630a3c7354ea8c9453aa3a4b280cebb2eceb
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/__init__.py
@@ -0,0 +1,159 @@
+__version__ = "0.3.1"
+
+
+# stores released dataset links and rollout horizons in global dictionary.
+# Structure is given below for each type of dataset:
+
+# robosuite / real
+# {
+# task:
+# dataset_type:
+# hdf5_type:
+# url: link
+# horizon: value
+# ...
+# ...
+# ...
+# }
+DATASET_REGISTRY = {}
+
+# momart
+# {
+# task:
+# dataset_type:
+# url: link
+# size: value
+# ...
+# ...
+# }
+MOMART_DATASET_REGISTRY = {}
+
+
+def register_dataset_link(task, dataset_type, hdf5_type, link, horizon):
+ """
+ Helper function to register dataset link in global dictionary.
+ Also takes a @horizon parameter - this corresponds to the evaluation
+ rollout horizon that should be used during training.
+
+ Args:
+ task (str): name of task for this dataset
+ dataset_type (str): type of dataset (usually identifies the dataset source)
+ hdf5_type (str): type of hdf5 - usually one of "raw", "low_dim", or "image",
+ to identify the kind of observations in the dataset
+ link (str): download link for the dataset
+ horizon (int): evaluation rollout horizon that should be used with this dataset
+ """
+ if task not in DATASET_REGISTRY:
+ DATASET_REGISTRY[task] = {}
+ if dataset_type not in DATASET_REGISTRY[task]:
+ DATASET_REGISTRY[task][dataset_type] = {}
+ DATASET_REGISTRY[task][dataset_type][hdf5_type] = dict(url=link, horizon=horizon)
+
+
+def register_all_links():
+ """
+ Record all dataset links in this function.
+ """
+
+ # all proficient human datasets
+ ph_tasks = ["lift", "can", "square", "transport", "tool_hang", "lift_real", "can_real", "tool_hang_real"]
+ ph_horizons = [400, 400, 400, 700, 700, 1000, 1000, 1000]
+ for task, horizon in zip(ph_tasks, ph_horizons):
+ register_dataset_link(task=task, dataset_type="ph", hdf5_type="raw", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/ph/demo{}.hdf5".format(
+ task, "" if "real" in task else "_v141"
+ )
+ )
+ # real world datasets only have demo.hdf5 files which already contain all observation modalities
+ # while sim datasets store raw low-dim mujoco states in the demo.hdf5
+ if "real" not in task:
+ register_dataset_link(task=task, dataset_type="ph", hdf5_type="low_dim", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/ph/low_dim_v141.hdf5".format(task))
+ register_dataset_link(task=task, dataset_type="ph", hdf5_type="image", horizon=horizon,
+ link=None)
+
+ # all multi human datasets
+ mh_tasks = ["lift", "can", "square", "transport"]
+ mh_horizons = [500, 500, 500, 1100]
+ for task, horizon in zip(mh_tasks, mh_horizons):
+ register_dataset_link(task=task, dataset_type="mh", hdf5_type="raw", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/mh/demo_v141.hdf5".format(task))
+ register_dataset_link(task=task, dataset_type="mh", hdf5_type="low_dim", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/mh/low_dim_v141.hdf5".format(task))
+ register_dataset_link(task=task, dataset_type="mh", hdf5_type="image", horizon=horizon,
+ link=None)
+
+ # all machine generated datasets
+ for task, horizon in zip(["lift", "can"], [400, 400]):
+ register_dataset_link(task=task, dataset_type="mg", hdf5_type="raw", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/mg/demo_v141.hdf5".format(task))
+ register_dataset_link(task=task, dataset_type="mg", hdf5_type="low_dim_sparse", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/mg/low_dim_sparse_v141.hdf5".format(task))
+ register_dataset_link(task=task, dataset_type="mg", hdf5_type="image_sparse", horizon=horizon,
+ link=None)
+ register_dataset_link(task=task, dataset_type="mg", hdf5_type="low_dim_dense", horizon=horizon,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/{}/mg/low_dim_dense_v141.hdf5".format(task))
+ register_dataset_link(task=task, dataset_type="mg", hdf5_type="image_dense", horizon=horizon,
+ link=None)
+
+ # can-paired dataset
+ register_dataset_link(task="can", dataset_type="paired", hdf5_type="raw", horizon=400,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/can/paired/demo_v141.hdf5")
+ register_dataset_link(task="can", dataset_type="paired", hdf5_type="low_dim", horizon=400,
+ link="http://downloads.cs.stanford.edu/downloads/rt_benchmark/can/paired/low_dim_v141.hdf5")
+ register_dataset_link(task="can", dataset_type="paired", hdf5_type="image", horizon=400,
+ link=None)
+
+
+def register_momart_dataset_link(task, dataset_type, link, dataset_size):
+ """
+ Helper function to register dataset link in global dictionary.
+ Also takes a @horizon parameter - this corresponds to the evaluation
+ rollout horizon that should be used during training.
+
+ Args:
+ task (str): name of task for this dataset
+ dataset_type (str): type of dataset (usually identifies the dataset source)
+ link (str): download link for the dataset
+ dataset_size (float): size of the dataset, in GB
+ """
+ if task not in MOMART_DATASET_REGISTRY:
+ MOMART_DATASET_REGISTRY[task] = {}
+ if dataset_type not in MOMART_DATASET_REGISTRY[task]:
+ MOMART_DATASET_REGISTRY[task][dataset_type] = {}
+ MOMART_DATASET_REGISTRY[task][dataset_type] = dict(url=link, size=dataset_size)
+
+
+def register_all_momart_links():
+ """
+ Record all dataset links in this function.
+ """
+ # all tasks, mapped to their [exp, sub, gen, sam] sizes
+ momart_tasks = {
+ "table_setup_from_dishwasher": [14, 14, 3.3, 0.6],
+ "table_setup_from_dresser": [16, 17, 3.1, 0.7],
+ "table_cleanup_to_dishwasher": [23, 36, 5.3, 1.1],
+ "table_cleanup_to_sink": [17, 28, 2.9, 0.8],
+ "unload_dishwasher": [21, 27, 5.4, 1.0],
+ }
+
+ momart_dataset_types = [
+ "expert",
+ "suboptimal",
+ "generalize",
+ "sample",
+ ]
+
+ # Iterate over all combos and register the link
+ for task, dataset_sizes in momart_tasks.items():
+ for dataset_type, dataset_size in zip(momart_dataset_types, dataset_sizes):
+ register_momart_dataset_link(
+ task=task,
+ dataset_type=dataset_type,
+ link=f"http://downloads.cs.stanford.edu/downloads/rt_mm/{dataset_type}/{task}_{dataset_type}.hdf5",
+ dataset_size=dataset_size,
+ )
+
+
+register_all_links()
+register_all_momart_links()
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/__init__.py b/phantom/submodules/phantom-robomimic/robomimic/algo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dedba35c9b70e100cae7da46580720f35dc28ef1
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/__init__.py
@@ -0,0 +1,12 @@
+from robomimic.algo.algo import register_algo_factory_func, algo_name_to_factory_func, algo_factory, Algo, PolicyAlgo, ValueAlgo, PlannerAlgo, HierarchicalAlgo, RolloutPolicy
+
+# note: these imports are needed to register these classes in the global algo registry
+from robomimic.algo.bc import BC, BC_Gaussian, BC_GMM, BC_VAE, BC_RNN, BC_RNN_GMM
+from robomimic.algo.bcq import BCQ, BCQ_GMM, BCQ_Distributional
+from robomimic.algo.cql import CQL
+from robomimic.algo.iql import IQL
+from robomimic.algo.gl import GL, GL_VAE, ValuePlanner
+from robomimic.algo.hbc import HBC
+from robomimic.algo.iris import IRIS
+from robomimic.algo.td3_bc import TD3_BC
+from robomimic.algo.diffusion_policy import DiffusionPolicyUNet
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/algo.py b/phantom/submodules/phantom-robomimic/robomimic/algo/algo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6289c214e15edbff66df6bfaceef25921ffb3b3b
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/algo.py
@@ -0,0 +1,574 @@
+"""
+This file contains base classes that other algorithm classes subclass.
+Each algorithm file also implements a algorithm factory function that
+takes in an algorithm config (`config.algo`) and returns the particular
+Algo subclass that should be instantiated, along with any extra kwargs.
+These factory functions are registered into a global dictionary with the
+@register_algo_factory_func function decorator. This makes it easy for
+@algo_factory to instantiate the correct `Algo` subclass.
+"""
+import textwrap
+from copy import deepcopy
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch
+
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.action_utils as AcUtils
+
+
+# mapping from algo name to factory functions that map algo configs to algo class names
+REGISTERED_ALGO_FACTORY_FUNCS = OrderedDict()
+
+
+def register_algo_factory_func(algo_name):
+ """
+ Function decorator to register algo factory functions that map algo configs to algo class names.
+ Each algorithm implements such a function, and decorates it with this decorator.
+
+ Args:
+ algo_name (str): the algorithm name to register the algorithm under
+ """
+ def decorator(factory_func):
+ REGISTERED_ALGO_FACTORY_FUNCS[algo_name] = factory_func
+ return decorator
+
+
+def algo_name_to_factory_func(algo_name):
+ """
+ Uses registry to retrieve algo factory function from algo name.
+
+ Args:
+ algo_name (str): the algorithm name
+ """
+ return REGISTERED_ALGO_FACTORY_FUNCS[algo_name]
+
+
+def algo_factory(algo_name, config, obs_key_shapes, ac_dim, device):
+ """
+ Factory function for creating algorithms based on the algorithm name and config.
+
+ Args:
+ algo_name (str): the algorithm name
+
+ config (BaseConfig instance): config object
+
+ obs_key_shapes (OrderedDict): dictionary that maps observation keys to shapes
+
+ ac_dim (int): dimension of action space
+
+ device (torch.Device): where the algo should live (i.e. cpu, gpu)
+ """
+
+ # @algo_name is included as an arg to be explicit, but make sure it matches the config
+ assert algo_name == config.algo_name
+
+ # use algo factory func to get algo class and kwargs from algo config
+ factory_func = algo_name_to_factory_func(algo_name)
+ algo_cls, algo_kwargs = factory_func(config.algo)
+
+ # create algo instance
+ return algo_cls(
+ algo_config=config.algo,
+ obs_config=config.observation,
+ global_config=config,
+ obs_key_shapes=obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device,
+ **algo_kwargs
+ )
+
+
+class Algo(object):
+ """
+ Base algorithm class that all other algorithms subclass. Defines several
+ functions that should be overriden by subclasses, in order to provide
+ a standard API to be used by training functions such as @run_epoch in
+ utils/train_utils.py.
+ """
+ def __init__(
+ self,
+ algo_config,
+ obs_config,
+ global_config,
+ obs_key_shapes,
+ ac_dim,
+ device
+ ):
+ """
+ Args:
+ algo_config (Config object): instance of Config corresponding to the algo section
+ of the config
+
+ obs_config (Config object): instance of Config corresponding to the observation
+ section of the config
+
+ global_config (Config object): global training config
+
+ obs_key_shapes (OrderedDict): dictionary that maps observation keys to shapes
+
+ ac_dim (int): dimension of action space
+
+ device (torch.Device): where the algo should live (i.e. cpu, gpu)
+ """
+ self.optim_params = deepcopy(algo_config.optim_params)
+ self.algo_config = algo_config
+ self.obs_config = obs_config
+ self.global_config = global_config
+
+ self.ac_dim = ac_dim
+ self.device = device
+ self.obs_key_shapes = obs_key_shapes
+
+ self.nets = nn.ModuleDict()
+ self._create_shapes(obs_config.modalities, obs_key_shapes)
+ self._create_networks()
+ self._create_optimizers()
+ assert isinstance(self.nets, nn.ModuleDict)
+
+ def _create_shapes(self, obs_keys, obs_key_shapes):
+ """
+ Create obs_shapes, goal_shapes, and subgoal_shapes dictionaries, to make it
+ easy for this algorithm object to keep track of observation key shapes. Each dictionary
+ maps observation key to shape.
+
+ Args:
+ obs_keys (dict): dict of required observation keys for this training run (usually
+ specified by the obs config), e.g., {"obs": ["rgb", "proprio"], "goal": ["proprio"]}
+ obs_key_shapes (dict): dict of observation key shapes, e.g., {"rgb": [3, 224, 224]}
+ """
+ # determine shapes
+ self.obs_shapes = OrderedDict()
+ self.goal_shapes = OrderedDict()
+ self.subgoal_shapes = OrderedDict()
+
+ # We check across all modality groups (obs, goal, subgoal), and see if the inputted observation key exists
+ # across all modalitie specified in the config. If so, we store its corresponding shape internally
+ for k in obs_key_shapes:
+ if "obs" in self.obs_config.modalities and k in [obs_key for modality in self.obs_config.modalities.obs.values() for obs_key in modality]:
+ self.obs_shapes[k] = obs_key_shapes[k]
+ if "goal" in self.obs_config.modalities and k in [obs_key for modality in self.obs_config.modalities.goal.values() for obs_key in modality]:
+ self.goal_shapes[k] = obs_key_shapes[k]
+ if "subgoal" in self.obs_config.modalities and k in [obs_key for modality in self.obs_config.modalities.subgoal.values() for obs_key in modality]:
+ self.subgoal_shapes[k] = obs_key_shapes[k]
+
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ @self.nets should be a ModuleDict.
+ """
+ raise NotImplementedError
+
+ def _create_optimizers(self):
+ """
+ Creates optimizers using @self.optim_params and places them into @self.optimizers.
+ """
+ self.optimizers = dict()
+ self.lr_schedulers = dict()
+
+ for k in self.optim_params:
+ # only make optimizers for networks that have been created - @optim_params may have more
+ # settings for unused networks
+ if k in self.nets:
+ if isinstance(self.nets[k], nn.ModuleList):
+ self.optimizers[k] = [
+ TorchUtils.optimizer_from_optim_params(net_optim_params=self.optim_params[k], net=self.nets[k][i])
+ for i in range(len(self.nets[k]))
+ ]
+ self.lr_schedulers[k] = [
+ TorchUtils.lr_scheduler_from_optim_params(net_optim_params=self.optim_params[k], net=self.nets[k][i], optimizer=self.optimizers[k][i])
+ for i in range(len(self.nets[k]))
+ ]
+ else:
+ self.optimizers[k] = TorchUtils.optimizer_from_optim_params(
+ net_optim_params=self.optim_params[k], net=self.nets[k])
+ self.lr_schedulers[k] = TorchUtils.lr_scheduler_from_optim_params(
+ net_optim_params=self.optim_params[k], net=self.nets[k], optimizer=self.optimizers[k])
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ return batch
+
+ def postprocess_batch_for_training(self, batch, obs_normalization_stats):
+ """
+ Does some operations (like channel swap, uint8 to float conversion, normalization)
+ after @process_batch_for_training is called, in order to ensure these operations
+ take place on GPU.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader. Assumed to be on the device where
+ training will occur (after @process_batch_for_training
+ is called)
+
+ obs_normalization_stats (dict or None): if provided, this should map observation
+ keys to dicts with a "mean" and "std" of shape (1, ...) where ... is the
+ default shape for the observation.
+
+ Returns:
+ batch (dict): postproceesed batch
+ """
+
+ # ensure obs_normalization_stats are torch Tensors on proper device
+ obs_normalization_stats = TensorUtils.to_float(TensorUtils.to_device(TensorUtils.to_tensor(obs_normalization_stats), self.device))
+
+ obs_keys = ["obs", "next_obs", "goal_obs"]
+ for k in obs_keys:
+ if k in batch and batch[k] is not None:
+ batch[k] = ObsUtils.process_obs_dict(batch[k])
+ if obs_normalization_stats is not None:
+ batch[k] = ObsUtils.normalize_dict(batch[k], obs_normalization_stats=obs_normalization_stats)
+ return batch
+
+ def postprocess_batch_for_training(self, batch, obs_normalization_stats):
+ """
+ Does some operations (like channel swap, uint8 to float conversion, normalization)
+ after @process_batch_for_training is called, in order to ensure these operations
+ take place on GPU.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader. Assumed to be on the device where
+ training will occur (after @process_batch_for_training
+ is called)
+
+ obs_normalization_stats (dict or None): if provided, this should map observation
+ keys to dicts with a "mean" and "std" of shape (1, ...) where ... is the
+ default shape for the observation.
+
+ Returns:
+ batch (dict): postproceesed batch
+ """
+
+ # ensure obs_normalization_stats are torch Tensors on proper device
+ obs_normalization_stats = TensorUtils.to_float(TensorUtils.to_device(TensorUtils.to_tensor(obs_normalization_stats), self.device))
+
+ # we will search the nested batch dictionary for the following special batch dict keys
+ # and apply the processing function to their values (which correspond to observations)
+ obs_keys = ["obs", "next_obs", "goal_obs"]
+
+ def recurse_helper(d):
+ """
+ Apply process_obs_dict to values in nested dictionary d that match a key in obs_keys.
+ """
+ for k in d:
+ if k in obs_keys:
+ # found key - stop search and process observation
+ if d[k] is not None:
+ d[k] = ObsUtils.process_obs_dict(d[k])
+ if obs_normalization_stats is not None:
+ d[k] = ObsUtils.normalize_dict(d[k], obs_normalization_stats=obs_normalization_stats)
+ elif isinstance(d[k], dict):
+ # search down into dictionary
+ recurse_helper(d[k])
+
+ recurse_helper(batch)
+ return batch
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ assert validate or self.nets.training
+ return OrderedDict()
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss log (dict): name -> summary statistic
+ """
+ log = OrderedDict()
+
+ # record current optimizer learning rates
+ for k in self.optimizers:
+ for i, param_group in enumerate(self.optimizers[k].param_groups):
+ log["Optimizer/{}{}_lr".format(k, i)] = param_group["lr"]
+
+ return log
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+
+ # LR scheduling updates
+ for k in self.lr_schedulers:
+ if self.lr_schedulers[k] is not None:
+ self.lr_schedulers[k].step()
+
+ def set_eval(self):
+ """
+ Prepare networks for evaluation.
+ """
+ self.nets.eval()
+
+ def set_train(self):
+ """
+ Prepare networks for training.
+ """
+ self.nets.train()
+
+ def serialize(self):
+ """
+ Get dictionary of current model parameters.
+ """
+ return self.nets.state_dict()
+
+ def deserialize(self, model_dict):
+ """
+ Load model from a checkpoint.
+
+ Args:
+ model_dict (dict): a dictionary saved by self.serialize() that contains
+ the same keys as @self.network_classes
+ """
+ self.nets.load_state_dict(model_dict)
+
+ def __repr__(self):
+ """
+ Pretty print algorithm and network description.
+ """
+ return "{} (\n".format(self.__class__.__name__) + \
+ textwrap.indent(self.nets.__repr__(), ' ') + "\n)"
+
+ def reset(self):
+ """
+ Reset algo state to prepare for environment rollouts.
+ """
+ pass
+
+
+class PolicyAlgo(Algo):
+ """
+ Base class for all algorithms that can be used as policies.
+ """
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ raise NotImplementedError
+
+
+class ValueAlgo(Algo):
+ """
+ Base class for all algorithms that can learn a value function.
+ """
+ def get_state_value(self, obs_dict, goal_dict=None):
+ """
+ Get state value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ raise NotImplementedError
+
+ def get_state_action_value(self, obs_dict, actions, goal_dict=None):
+ """
+ Get state-action value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ actions (torch.Tensor): action
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ raise NotImplementedError
+
+
+class PlannerAlgo(Algo):
+ """
+ Base class for all algorithms that can be used for planning subgoals
+ conditioned on current observations and potential goal observations.
+ """
+ def get_subgoal_predictions(self, obs_dict, goal_dict=None):
+ """
+ Get predicted subgoal outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoal prediction (dict): name -> Tensor [batch_size, ...]
+ """
+ raise NotImplementedError
+
+ def sample_subgoals(self, obs_dict, goal_dict, num_samples=1):
+ """
+ For planners that rely on sampling subgoals.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
+ """
+ raise NotImplementedError
+
+
+class HierarchicalAlgo(Algo):
+ """
+ Base class for all hierarchical algorithms that consist of (1) subgoal planning
+ and (2) subgoal-conditioned policy learning.
+ """
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ raise NotImplementedError
+
+ def get_subgoal_predictions(self, obs_dict, goal_dict=None):
+ """
+ Get subgoal predictions from high-level subgoal planner.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoal (dict): predicted subgoal
+ """
+ raise NotImplementedError
+
+ @property
+ def current_subgoal(self):
+ """
+ Get the current subgoal for conditioning the low-level policy
+
+ Returns:
+ current subgoal (dict): predicted subgoal
+ """
+ raise NotImplementedError
+
+
+class RolloutPolicy(object):
+ """
+ Wraps @Algo object to make it easy to run policies in a rollout loop.
+ """
+ def __init__(self, policy, obs_normalization_stats=None, action_normalization_stats=None):
+ """
+ Args:
+ policy (Algo instance): @Algo object to wrap to prepare for rollouts
+
+ obs_normalization_stats (dict): optionally pass a dictionary for observation
+ normalization. This should map observation keys to dicts
+ with a "mean" and "std" of shape (1, ...) where ... is the default
+ shape for the observation.
+ """
+ self.policy = policy
+ self.obs_normalization_stats = obs_normalization_stats
+ self.action_normalization_stats = action_normalization_stats
+
+ def start_episode(self):
+ """
+ Prepare the policy to start a new rollout.
+ """
+ self.policy.set_eval()
+ self.policy.reset()
+
+ def _prepare_observation(self, ob):
+ """
+ Prepare raw observation dict from environment for policy.
+
+ Args:
+ ob (dict): single observation dictionary from environment (no batch dimension,
+ and np.array values for each key)
+ """
+ ob = TensorUtils.to_tensor(ob)
+ ob = TensorUtils.to_batch(ob)
+ ob = TensorUtils.to_device(ob, self.policy.device)
+ ob = TensorUtils.to_float(ob)
+ if self.obs_normalization_stats is not None:
+ # ensure obs_normalization_stats are torch Tensors on proper device
+ obs_normalization_stats = TensorUtils.to_float(TensorUtils.to_device(TensorUtils.to_tensor(self.obs_normalization_stats), self.policy.device))
+ # limit normalization to obs keys being used, in case environment includes extra keys
+ ob = { k : ob[k] for k in self.policy.global_config.all_obs_keys }
+ ob = ObsUtils.normalize_dict(ob, normalization_stats=obs_normalization_stats)
+ return ob
+
+ def __repr__(self):
+ """Pretty print network description"""
+ return self.policy.__repr__()
+
+ def __call__(self, ob, goal=None):
+ """
+ Produce action from raw observation dict (and maybe goal dict) from environment.
+
+ Args:
+ ob (dict): single observation dictionary from environment (no batch dimension,
+ and np.array values for each key)
+ goal (dict): goal observation
+ """
+ ob = self._prepare_observation(ob)
+ if goal is not None:
+ goal = self._prepare_observation(goal)
+ ac = self.policy.get_action(obs_dict=ob, goal_dict=goal)
+ ac = TensorUtils.to_numpy(ac[0])
+ if self.action_normalization_stats is not None:
+ action_keys = self.policy.global_config.train.action_keys
+ action_shapes = {k: self.action_normalization_stats[k]["offset"].shape[1:] for k in self.action_normalization_stats}
+ ac_dict = AcUtils.vector_to_action_dict(ac, action_shapes=action_shapes, action_keys=action_keys)
+ ac_dict = ObsUtils.unnormalize_dict(ac_dict, normalization_stats=self.action_normalization_stats)
+ action_config = self.policy.global_config.train.action_config
+ for key, value in ac_dict.items():
+ this_format = action_config[key].get('format', None)
+ if this_format == 'rot_6d':
+ rot_6d = torch.from_numpy(value).unsqueeze(0)
+ rot = TorchUtils.rot_6d_to_axis_angle(rot_6d=rot_6d).squeeze().numpy()
+ ac_dict[key] = rot
+ ac = AcUtils.action_dict_to_vector(ac_dict, action_keys=action_keys)
+ return ac
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/bc.py b/phantom/submodules/phantom-robomimic/robomimic/algo/bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0797b7eae94b792833857ca6e04958084d4001dc
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/bc.py
@@ -0,0 +1,875 @@
+"""
+Implementation of Behavioral Cloning (BC).
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as D
+
+import robomimic.models.base_nets as BaseNets
+import robomimic.models.obs_nets as ObsNets
+import robomimic.models.policy_nets as PolicyNets
+import robomimic.models.vae_nets as VAENets
+import robomimic.utils.loss_utils as LossUtils
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+
+from robomimic.algo import register_algo_factory_func, PolicyAlgo
+
+
+@register_algo_factory_func("bc")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the BC algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+
+ # note: we need the check below because some configs import BCConfig and exclude
+ # some of these options
+ gaussian_enabled = ("gaussian" in algo_config and algo_config.gaussian.enabled)
+ gmm_enabled = ("gmm" in algo_config and algo_config.gmm.enabled)
+ vae_enabled = ("vae" in algo_config and algo_config.vae.enabled)
+
+ rnn_enabled = algo_config.rnn.enabled
+ # support legacy configs that do not have "transformer" item
+ transformer_enabled = ("transformer" in algo_config) and algo_config.transformer.enabled
+
+ if gaussian_enabled:
+ if rnn_enabled:
+ raise NotImplementedError
+ elif transformer_enabled:
+ raise NotImplementedError
+ else:
+ algo_class, algo_kwargs = BC_Gaussian, {}
+ elif gmm_enabled:
+ if rnn_enabled:
+ algo_class, algo_kwargs = BC_RNN_GMM, {}
+ elif transformer_enabled:
+ algo_class, algo_kwargs = BC_Transformer_GMM, {}
+ else:
+ algo_class, algo_kwargs = BC_GMM, {}
+ elif vae_enabled:
+ if rnn_enabled:
+ raise NotImplementedError
+ elif transformer_enabled:
+ raise NotImplementedError
+ else:
+ algo_class, algo_kwargs = BC_VAE, {}
+ else:
+ if rnn_enabled:
+ algo_class, algo_kwargs = BC_RNN, {}
+ elif transformer_enabled:
+ algo_class, algo_kwargs = BC_Transformer, {}
+ else:
+ algo_class, algo_kwargs = BC, {}
+
+ return algo_class, algo_kwargs
+
+
+class BC(PolicyAlgo):
+ """
+ Normal BC training.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.ActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor_layer_dims,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+ self.nets = self.nets.float().to(self.device)
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+ input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"][:, 0, :]
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ info = super(BC, self).train_on_batch(batch, epoch, validate=validate)
+ predictions = self._forward_training(batch)
+ losses = self._compute_losses(predictions, batch)
+
+ info["predictions"] = TensorUtils.detach(predictions)
+ info["losses"] = TensorUtils.detach(losses)
+
+ if not validate:
+ step_info = self._train_step(losses)
+ info.update(step_info)
+
+ return info
+
+ def _forward_training(self, batch):
+ """
+ Internal helper function for BC algo class. Compute forward pass
+ and return network outputs in @predictions dict.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ predictions (dict): dictionary containing network outputs
+ """
+ predictions = OrderedDict()
+ actions = self.nets["policy"](obs_dict=batch["obs"], goal_dict=batch["goal_obs"])
+ predictions["actions"] = actions
+ return predictions
+
+ def _compute_losses(self, predictions, batch):
+ """
+ Internal helper function for BC algo class. Compute losses based on
+ network outputs in @predictions dict, using reference labels in @batch.
+
+ Args:
+ predictions (dict): dictionary containing network outputs, from @_forward_training
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ losses (dict): dictionary of losses computed over the batch
+ """
+ losses = OrderedDict()
+ a_target = batch["actions"]
+ actions = predictions["actions"]
+ losses["l2_loss"] = nn.MSELoss()(actions, a_target)
+ losses["l1_loss"] = nn.SmoothL1Loss()(actions, a_target)
+ # cosine direction loss on eef delta position
+ losses["cos_loss"] = LossUtils.cosine_loss(actions[..., :3], a_target[..., :3])
+
+ action_losses = [
+ self.algo_config.loss.l2_weight * losses["l2_loss"],
+ self.algo_config.loss.l1_weight * losses["l1_loss"],
+ self.algo_config.loss.cos_weight * losses["cos_loss"],
+ ]
+ action_loss = sum(action_losses)
+ losses["action_loss"] = action_loss
+ return losses
+
+ def _train_step(self, losses):
+ """
+ Internal helper function for BC algo class. Perform backpropagation on the
+ loss tensors in @losses to update networks.
+
+ Args:
+ losses (dict): dictionary of losses computed over the batch, from @_compute_losses
+ """
+
+ # gradient step
+ info = OrderedDict()
+ policy_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["policy"],
+ optim=self.optimizers["policy"],
+ loss=losses["action_loss"],
+ )
+ info["policy_grad_norms"] = policy_grad_norms
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = super(BC, self).log_info(info)
+ log["Loss"] = info["losses"]["action_loss"].item()
+ if "l2_loss" in info["losses"]:
+ log["L2_Loss"] = info["losses"]["l2_loss"].item()
+ if "l1_loss" in info["losses"]:
+ log["L1_Loss"] = info["losses"]["l1_loss"].item()
+ if "cos_loss" in info["losses"]:
+ log["Cosine_Loss"] = info["losses"]["cos_loss"].item()
+ if "policy_grad_norms" in info:
+ log["Policy_Grad_Norms"] = info["policy_grad_norms"]
+ return log
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+ return self.nets["policy"](obs_dict, goal_dict=goal_dict)
+
+
+class BC_Gaussian(BC):
+ """
+ BC training with a Gaussian policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ assert self.algo_config.gaussian.enabled
+
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.GaussianActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor_layer_dims,
+ fixed_std=self.algo_config.gaussian.fixed_std,
+ init_std=self.algo_config.gaussian.init_std,
+ std_limits=(self.algo_config.gaussian.min_std, 7.5),
+ std_activation=self.algo_config.gaussian.std_activation,
+ low_noise_eval=self.algo_config.gaussian.low_noise_eval,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+ def _forward_training(self, batch):
+ """
+ Internal helper function for BC algo class. Compute forward pass
+ and return network outputs in @predictions dict.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ predictions (dict): dictionary containing network outputs
+ """
+ dists = self.nets["policy"].forward_train(
+ obs_dict=batch["obs"],
+ goal_dict=batch["goal_obs"],
+ )
+
+ # make sure that this is a batch of multivariate action distributions, so that
+ # the log probability computation will be correct
+ assert len(dists.batch_shape) == 1
+ log_probs = dists.log_prob(batch["actions"])
+
+ predictions = OrderedDict(
+ log_probs=log_probs,
+ )
+ return predictions
+
+ def _compute_losses(self, predictions, batch):
+ """
+ Internal helper function for BC algo class. Compute losses based on
+ network outputs in @predictions dict, using reference labels in @batch.
+
+ Args:
+ predictions (dict): dictionary containing network outputs, from @_forward_training
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ losses (dict): dictionary of losses computed over the batch
+ """
+
+ # loss is just negative log-likelihood of action targets
+ action_loss = -predictions["log_probs"].mean()
+ return OrderedDict(
+ log_probs=-action_loss,
+ action_loss=action_loss,
+ )
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = PolicyAlgo.log_info(self, info)
+ log["Loss"] = info["losses"]["action_loss"].item()
+ log["Log_Likelihood"] = info["losses"]["log_probs"].item()
+ if "policy_grad_norms" in info:
+ log["Policy_Grad_Norms"] = info["policy_grad_norms"]
+ return log
+
+
+class BC_GMM(BC_Gaussian):
+ """
+ BC training with a Gaussian Mixture Model policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ assert self.algo_config.gmm.enabled
+
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.GMMActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor_layer_dims,
+ num_modes=self.algo_config.gmm.num_modes,
+ min_std=self.algo_config.gmm.min_std,
+ std_activation=self.algo_config.gmm.std_activation,
+ low_noise_eval=self.algo_config.gmm.low_noise_eval,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+
+class BC_VAE(BC):
+ """
+ BC training with a VAE policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.VAEActor(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ device=self.device,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **VAENets.vae_args_from_config(self.algo_config.vae),
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Update from superclass to set categorical temperature, for categorical VAEs.
+ """
+ if self.algo_config.vae.prior.use_categorical:
+ temperature = self.algo_config.vae.prior.categorical_init_temp - epoch * self.algo_config.vae.prior.categorical_temp_anneal_step
+ temperature = max(temperature, self.algo_config.vae.prior.categorical_min_temp)
+ self.nets["policy"].set_gumbel_temperature(temperature)
+ return super(BC_VAE, self).train_on_batch(batch, epoch, validate=validate)
+
+ def _forward_training(self, batch):
+ """
+ Internal helper function for BC algo class. Compute forward pass
+ and return network outputs in @predictions dict.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ predictions (dict): dictionary containing network outputs
+ """
+ vae_inputs = dict(
+ actions=batch["actions"],
+ obs_dict=batch["obs"],
+ goal_dict=batch["goal_obs"],
+ freeze_encoder=batch.get("freeze_encoder", False),
+ )
+
+ vae_outputs = self.nets["policy"].forward_train(**vae_inputs)
+ predictions = OrderedDict(
+ actions=vae_outputs["decoder_outputs"],
+ kl_loss=vae_outputs["kl_loss"],
+ reconstruction_loss=vae_outputs["reconstruction_loss"],
+ encoder_z=vae_outputs["encoder_z"],
+ )
+ if not self.algo_config.vae.prior.use_categorical:
+ with torch.no_grad():
+ encoder_variance = torch.exp(vae_outputs["encoder_params"]["logvar"])
+ predictions["encoder_variance"] = encoder_variance
+ return predictions
+
+ def _compute_losses(self, predictions, batch):
+ """
+ Internal helper function for BC algo class. Compute losses based on
+ network outputs in @predictions dict, using reference labels in @batch.
+
+ Args:
+ predictions (dict): dictionary containing network outputs, from @_forward_training
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ losses (dict): dictionary of losses computed over the batch
+ """
+
+ # total loss is sum of reconstruction and KL, weighted by beta
+ kl_loss = predictions["kl_loss"]
+ recons_loss = predictions["reconstruction_loss"]
+ action_loss = recons_loss + self.algo_config.vae.kl_weight * kl_loss
+ return OrderedDict(
+ recons_loss=recons_loss,
+ kl_loss=kl_loss,
+ action_loss=action_loss,
+ )
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = PolicyAlgo.log_info(self, info)
+ log["Loss"] = info["losses"]["action_loss"].item()
+ log["KL_Loss"] = info["losses"]["kl_loss"].item()
+ log["Reconstruction_Loss"] = info["losses"]["recons_loss"].item()
+ if self.algo_config.vae.prior.use_categorical:
+ log["Gumbel_Temperature"] = self.nets["policy"].get_gumbel_temperature()
+ else:
+ log["Encoder_Variance"] = info["predictions"]["encoder_variance"].mean().item()
+ if "policy_grad_norms" in info:
+ log["Policy_Grad_Norms"] = info["policy_grad_norms"]
+ return log
+
+
+class BC_RNN(BC):
+ """
+ BC training with an RNN policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.RNNActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor_layer_dims,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **BaseNets.rnn_args_from_config(self.algo_config.rnn),
+ )
+
+ self._rnn_hidden_state = None
+ self._rnn_horizon = self.algo_config.rnn.horizon
+ self._rnn_counter = 0
+ self._rnn_is_open_loop = self.algo_config.rnn.get("open_loop", False)
+
+ self.nets = self.nets.float().to(self.device)
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+ input_batch["obs"] = batch["obs"]
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"]
+
+ if self._rnn_is_open_loop:
+ # replace the observation sequence with one that only consists of the first observation.
+ # This way, all actions are predicted "open-loop" after the first observation, based
+ # on the rnn hidden state.
+ n_steps = batch["actions"].shape[1]
+ obs_seq_start = TensorUtils.index_at_time(batch["obs"], ind=0)
+ input_batch["obs"] = TensorUtils.unsqueeze_expand_at(obs_seq_start, size=n_steps, dim=1)
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+
+ if self._rnn_hidden_state is None or self._rnn_counter % self._rnn_horizon == 0:
+ batch_size = list(obs_dict.values())[0].shape[0]
+ self._rnn_hidden_state = self.nets["policy"].get_rnn_init_state(batch_size=batch_size, device=self.device)
+
+ if self._rnn_is_open_loop:
+ # remember the initial observation, and use it instead of the current observation
+ # for open-loop action sequence prediction
+ self._open_loop_obs = TensorUtils.clone(TensorUtils.detach(obs_dict))
+
+ obs_to_use = obs_dict
+ if self._rnn_is_open_loop:
+ # replace current obs with last recorded obs
+ obs_to_use = self._open_loop_obs
+
+ self._rnn_counter += 1
+ action, self._rnn_hidden_state = self.nets["policy"].forward_step(
+ obs_to_use, goal_dict=goal_dict, rnn_state=self._rnn_hidden_state)
+ return action
+
+ def reset(self):
+ """
+ Reset algo state to prepare for environment rollouts.
+ """
+ self._rnn_hidden_state = None
+ self._rnn_counter = 0
+
+
+class BC_RNN_GMM(BC_RNN):
+ """
+ BC training with an RNN GMM policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ assert self.algo_config.gmm.enabled
+ assert self.algo_config.rnn.enabled
+
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.RNNGMMActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor_layer_dims,
+ num_modes=self.algo_config.gmm.num_modes,
+ min_std=self.algo_config.gmm.min_std,
+ std_activation=self.algo_config.gmm.std_activation,
+ low_noise_eval=self.algo_config.gmm.low_noise_eval,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **BaseNets.rnn_args_from_config(self.algo_config.rnn),
+ )
+
+ self._rnn_hidden_state = None
+ self._rnn_horizon = self.algo_config.rnn.horizon
+ self._rnn_counter = 0
+ self._rnn_is_open_loop = self.algo_config.rnn.get("open_loop", False)
+
+ self.nets = self.nets.float().to(self.device)
+
+ def _forward_training(self, batch):
+ """
+ Internal helper function for BC algo class. Compute forward pass
+ and return network outputs in @predictions dict.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ predictions (dict): dictionary containing network outputs
+ """
+ dists = self.nets["policy"].forward_train(
+ obs_dict=batch["obs"],
+ goal_dict=batch["goal_obs"],
+ )
+
+ # make sure that this is a batch of multivariate action distributions, so that
+ # the log probability computation will be correct
+ assert len(dists.batch_shape) == 2 # [B, T]
+ log_probs = dists.log_prob(batch["actions"])
+
+ predictions = OrderedDict(
+ log_probs=log_probs,
+ )
+ return predictions
+
+ def _compute_losses(self, predictions, batch):
+ """
+ Internal helper function for BC algo class. Compute losses based on
+ network outputs in @predictions dict, using reference labels in @batch.
+
+ Args:
+ predictions (dict): dictionary containing network outputs, from @_forward_training
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ losses (dict): dictionary of losses computed over the batch
+ """
+
+ # loss is just negative log-likelihood of action targets
+ action_loss = -predictions["log_probs"].mean()
+ return OrderedDict(
+ log_probs=-action_loss,
+ action_loss=action_loss,
+ )
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = PolicyAlgo.log_info(self, info)
+ log["Loss"] = info["losses"]["action_loss"].item()
+ log["Log_Likelihood"] = info["losses"]["log_probs"].item()
+ if "policy_grad_norms" in info:
+ log["Policy_Grad_Norms"] = info["policy_grad_norms"]
+ return log
+
+
+class BC_Transformer(BC):
+ """
+ BC training with a Transformer policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ assert self.algo_config.transformer.enabled
+
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.TransformerActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **BaseNets.transformer_args_from_config(self.algo_config.transformer),
+ )
+ self._set_params_from_config()
+ self.nets = self.nets.float().to(self.device)
+
+ def _set_params_from_config(self):
+ """
+ Read specific config variables we need for training / eval.
+ Called by @_create_networks method
+ """
+ self.context_length = self.algo_config.transformer.context_length
+ self.supervise_all_steps = self.algo_config.transformer.supervise_all_steps
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+ h = self.context_length
+ input_batch["obs"] = {k: batch["obs"][k][:, :h, :] for k in batch["obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+
+ if self.supervise_all_steps:
+ # supervision on entire sequence (instead of just current timestep)
+ input_batch["actions"] = batch["actions"][:, :h, :]
+ else:
+ # just use current timestep
+ input_batch["actions"] = batch["actions"][:, h-1, :]
+
+ input_batch = TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device)
+ return input_batch
+
+ def _forward_training(self, batch, epoch=None):
+ """
+ Internal helper function for BC_Transformer algo class. Compute forward pass
+ and return network outputs in @predictions dict.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ predictions (dict): dictionary containing network outputs
+ """
+ # ensure that transformer context length is consistent with temporal dimension of observations
+ TensorUtils.assert_size_at_dim(
+ batch["obs"],
+ size=(self.context_length),
+ dim=1,
+ msg="Error: expect temporal dimension of obs batch to match transformer context length {}".format(self.context_length),
+ )
+
+ predictions = OrderedDict()
+ predictions["actions"] = self.nets["policy"](obs_dict=batch["obs"], actions=None, goal_dict=batch["goal_obs"])
+ if not self.supervise_all_steps:
+ # only supervise final timestep
+ predictions["actions"] = predictions["actions"][:, -1, :]
+ return predictions
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["policy"](obs_dict, actions=None, goal_dict=goal_dict)[:, -1, :]
+
+
+class BC_Transformer_GMM(BC_Transformer):
+ """
+ BC training with a Transformer GMM policy.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ assert self.algo_config.gmm.enabled
+ assert self.algo_config.transformer.enabled
+
+ self.nets = nn.ModuleDict()
+ self.nets["policy"] = PolicyNets.TransformerGMMActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ num_modes=self.algo_config.gmm.num_modes,
+ min_std=self.algo_config.gmm.min_std,
+ std_activation=self.algo_config.gmm.std_activation,
+ low_noise_eval=self.algo_config.gmm.low_noise_eval,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **BaseNets.transformer_args_from_config(self.algo_config.transformer),
+ )
+ self._set_params_from_config()
+ self.nets = self.nets.float().to(self.device)
+
+ def _forward_training(self, batch, epoch=None):
+ """
+ Modify from super class to support GMM training.
+ """
+ # ensure that transformer context length is consistent with temporal dimension of observations
+ TensorUtils.assert_size_at_dim(
+ batch["obs"],
+ size=(self.context_length),
+ dim=1,
+ msg="Error: expect temporal dimension of obs batch to match transformer context length {}".format(self.context_length),
+ )
+
+ dists = self.nets["policy"].forward_train(
+ obs_dict=batch["obs"],
+ actions=None,
+ goal_dict=batch["goal_obs"],
+ low_noise_eval=False,
+ )
+
+ # make sure that this is a batch of multivariate action distributions, so that
+ # the log probability computation will be correct
+ assert len(dists.batch_shape) == 2 # [B, T]
+
+ if not self.supervise_all_steps:
+ # only use final timestep prediction by making a new distribution with only final timestep.
+ # This essentially does `dists = dists[:, -1]`
+ component_distribution = D.Normal(
+ loc=dists.component_distribution.base_dist.loc[:, -1],
+ scale=dists.component_distribution.base_dist.scale[:, -1],
+ )
+ component_distribution = D.Independent(component_distribution, 1)
+ mixture_distribution = D.Categorical(logits=dists.mixture_distribution.logits[:, -1])
+ dists = D.MixtureSameFamily(
+ mixture_distribution=mixture_distribution,
+ component_distribution=component_distribution,
+ )
+
+ log_probs = dists.log_prob(batch["actions"])
+
+ predictions = OrderedDict(
+ log_probs=log_probs,
+ )
+ return predictions
+
+ def _compute_losses(self, predictions, batch):
+ """
+ Internal helper function for BC_Transformer_GMM algo class. Compute losses based on
+ network outputs in @predictions dict, using reference labels in @batch.
+ Args:
+ predictions (dict): dictionary containing network outputs, from @_forward_training
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+ Returns:
+ losses (dict): dictionary of losses computed over the batch
+ """
+
+ # loss is just negative log-likelihood of action targets
+ action_loss = -predictions["log_probs"].mean()
+ return OrderedDict(
+ log_probs=-action_loss,
+ action_loss=action_loss,
+ )
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+ Args:
+ info (dict): dictionary of info
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = PolicyAlgo.log_info(self, info)
+ log["Loss"] = info["losses"]["action_loss"].item()
+ log["Log_Likelihood"] = info["losses"]["log_probs"].item()
+ if "policy_grad_norms" in info:
+ log["Policy_Grad_Norms"] = info["policy_grad_norms"]
+ return log
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/bcq.py b/phantom/submodules/phantom-robomimic/robomimic/algo/bcq.py
new file mode 100644
index 0000000000000000000000000000000000000000..5843ccb5bd594c596a8dc138eab863bb3f5e3550
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/bcq.py
@@ -0,0 +1,1022 @@
+"""
+Batch-Constrained Q-Learning (BCQ), with support for more general
+generative action models (the original paper uses a cVAE).
+(Paper - https://arxiv.org/abs/1812.02900).
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import robomimic.models.obs_nets as ObsNets
+import robomimic.models.policy_nets as PolicyNets
+import robomimic.models.value_nets as ValueNets
+import robomimic.models.vae_nets as VAENets
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.loss_utils as LossUtils
+
+from robomimic.algo import register_algo_factory_func, PolicyAlgo, ValueAlgo
+
+
+@register_algo_factory_func("bcq")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the BCQ algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ if algo_config.critic.distributional.enabled:
+ return BCQ_Distributional, {}
+ if algo_config.action_sampler.gmm.enabled:
+ return BCQ_GMM, {}
+ assert algo_config.action_sampler.vae.enabled
+ return BCQ, {}
+
+
+class BCQ(PolicyAlgo, ValueAlgo):
+ """
+ Default BCQ training, based on https://arxiv.org/abs/1812.02900 and
+ https://github.com/sfujim/BCQ
+ """
+ def __init__(self, **kwargs):
+ PolicyAlgo.__init__(self, **kwargs)
+
+ # save the discount factor - it may be overriden later
+ self.set_discount(self.algo_config.discount)
+
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+
+ self._create_critics()
+ self._create_action_sampler()
+ if self.algo_config.actor.enabled:
+ self._create_actor()
+
+ # sync target networks at beginning of training
+ with torch.no_grad():
+ for critic_ind in range(len(self.nets["critic"])):
+ TorchUtils.hard_update(
+ source=self.nets["critic"][critic_ind],
+ target=self.nets["critic_target"][critic_ind],
+ )
+
+ if self.algo_config.actor.enabled:
+ TorchUtils.hard_update(
+ source=self.nets["actor"],
+ target=self.nets["actor_target"],
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+ def _create_critics(self):
+ """
+ Called in @_create_networks to make critic networks.
+ """
+ critic_class = ValueNets.ActionValueNetwork
+ critic_args = dict(
+ obs_shapes=self.obs_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.critic.layer_dims,
+ value_bounds=self.algo_config.critic.value_bounds,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ # Q network ensemble and target ensemble
+ self.nets["critic"] = nn.ModuleList()
+ self.nets["critic_target"] = nn.ModuleList()
+ for _ in range(self.algo_config.critic.ensemble.n):
+ critic = critic_class(**critic_args)
+ self.nets["critic"].append(critic)
+
+ critic_target = critic_class(**critic_args)
+ self.nets["critic_target"].append(critic_target)
+
+ def _create_action_sampler(self):
+ """
+ Called in @_create_networks to make action sampler network.
+ """
+
+ # VAE network for approximate sampling from batch dataset
+ assert self.algo_config.action_sampler.vae.enabled
+ self.nets["action_sampler"] = PolicyNets.VAEActor(
+ obs_shapes=self.obs_shapes,
+ ac_dim=self.ac_dim,
+ device=self.device,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **VAENets.vae_args_from_config(self.algo_config.action_sampler.vae),
+ )
+
+ def _create_actor(self):
+ """
+ Called in @_create_networks to make actor network.
+ """
+ assert self.algo_config.actor.enabled
+ actor_class = PolicyNets.PerturbationActorNetwork
+ actor_args = dict(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor.layer_dims,
+ perturbation_scale=self.algo_config.actor.perturbation_scale,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ self.nets["actor"] = actor_class(**actor_args)
+ self.nets["actor_target"] = actor_class(**actor_args)
+
+ def _check_epoch(self, net_name, epoch):
+ """
+ Helper function to check whether backprop should happen this epoch.
+
+ Args:
+ net_name (str): name of network in @self.nets and @self.optim_params
+ epoch (int): epoch number
+ """
+ epoch_start_check = (self.optim_params[net_name]["start_epoch"] == -1) or (epoch >= self.optim_params[net_name]["start_epoch"])
+ epoch_end_check = (self.optim_params[net_name]["end_epoch"] == -1) or (epoch < self.optim_params[net_name]["end_epoch"])
+ return (epoch_start_check and epoch_end_check)
+
+ def set_discount(self, discount):
+ """
+ Useful function to modify discount factor if necessary (e.g. for n-step returns).
+ """
+ self.discount = discount
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ # n-step returns (default is 1)
+ n_step = self.algo_config.n_step
+ assert batch["actions"].shape[1] >= n_step
+
+ # remove temporal batches for all
+ input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
+ input_batch["next_obs"] = {k: batch["next_obs"][k][:, n_step - 1, :] for k in batch["next_obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"][:, 0, :]
+
+ # note: ensure scalar signals (rewards, done) retain last dimension of 1 to be compatible with model outputs
+
+ # single timestep reward is discounted sum of intermediate rewards in sequence
+ reward_seq = batch["rewards"][:, :n_step]
+ discounts = torch.pow(self.algo_config.discount, torch.arange(n_step).float()).unsqueeze(0)
+ input_batch["rewards"] = (reward_seq * discounts).sum(dim=1).unsqueeze(1)
+
+ # discount rate will be gamma^N for computing n-step returns
+ new_discount = (self.algo_config.discount ** n_step)
+ self.set_discount(new_discount)
+
+ # consider this n-step seqeunce done if any intermediate dones are present
+ done_seq = batch["dones"][:, :n_step]
+ input_batch["dones"] = (done_seq.sum(dim=1) > 0).float().unsqueeze(1)
+
+ if self.algo_config.infinite_horizon:
+ # scale terminal rewards by 1 / (1 - gamma) for infinite horizon MDPs
+ done_inds = input_batch["dones"].round().long().nonzero(as_tuple=False)[:, 0]
+ if done_inds.shape[0] > 0:
+ input_batch["rewards"][done_inds] = input_batch["rewards"][done_inds] * (1. / (1. - self.discount))
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def _train_action_sampler_on_batch(self, batch, epoch, no_backprop=False):
+ """
+ A modular helper function that can be overridden in case
+ subclasses would like to modify training behavior for the
+ action sampler.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ no_backprop (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ outputs (dict): dictionary of outputs to use during critic training
+ (for computing target values)
+ """
+ info = OrderedDict()
+ if self.algo_config.action_sampler.vae.prior.use_categorical:
+ temperature = self.algo_config.action_sampler.vae.prior.categorical_init_temp - epoch * self.algo_config.action_sampler.vae.prior.categorical_temp_anneal_step
+ temperature = max(temperature, self.algo_config.action_sampler.vae.prior.categorical_min_temp)
+ self.nets["action_sampler"].set_gumbel_temperature(temperature)
+
+ vae_inputs = dict(
+ actions=batch["actions"],
+ obs_dict=batch["obs"],
+ goal_dict=batch["goal_obs"],
+ )
+
+ # maybe freeze encoder weights
+ if (self.algo_config.action_sampler.freeze_encoder_epoch != -1) and (epoch >= self.algo_config.action_sampler.freeze_encoder_epoch):
+ vae_inputs["freeze_encoder"] = True
+
+ # VAE forward
+ vae_outputs = self.nets["action_sampler"].forward_train(**vae_inputs)
+ recons_loss = vae_outputs["reconstruction_loss"]
+ kl_loss = vae_outputs["kl_loss"]
+ vae_loss = recons_loss + self.algo_config.action_sampler.vae.kl_weight * kl_loss
+ info["action_sampler/loss"] = vae_loss
+ info["action_sampler/recons_loss"] = recons_loss
+ info["action_sampler/kl_loss"] = kl_loss
+ if not self.algo_config.action_sampler.vae.prior.use_categorical:
+ with torch.no_grad():
+ encoder_variance = torch.exp(vae_outputs["encoder_params"]["logvar"]).mean()
+ info["action_sampler/encoder_variance"] = encoder_variance
+ outputs = TensorUtils.detach(vae_outputs)
+
+ # VAE gradient step
+ if not no_backprop:
+ vae_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["action_sampler"],
+ optim=self.optimizers["action_sampler"],
+ loss=vae_loss,
+ )
+ info["action_sampler/grad_norms"] = vae_grad_norms
+ return info, outputs
+
+ def _train_critic_on_batch(self, batch, action_sampler_outputs, epoch, no_backprop=False):
+ """
+ A modular helper function that can be overridden in case
+ subclasses would like to modify training behavior for the
+ critics.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ action_sampler_outputs (dict): dictionary of outputs from the action sampler. Used
+ to form target values for training the critic
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ no_backprop (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ critic_outputs (dict): dictionary of critic outputs - useful for
+ logging purposes
+ """
+ info = OrderedDict()
+
+ # batch variables
+ s_batch = batch["obs"]
+ a_batch = batch["actions"]
+ r_batch = batch["rewards"]
+ ns_batch = batch["next_obs"]
+ goal_s_batch = batch["goal_obs"]
+
+ # 1 if not done, 0 otherwise
+ done_mask_batch = 1. - batch["dones"]
+ info["done_masks"] = done_mask_batch
+
+ # Bellman backup for Q-targets
+ q_targets = self._get_target_values(
+ next_states=ns_batch,
+ goal_states=goal_s_batch,
+ rewards=r_batch,
+ dones=done_mask_batch,
+ action_sampler_outputs=action_sampler_outputs,
+ )
+ info["critic/q_targets"] = q_targets
+
+ # Train all critics using this set of targets for regression
+ critic_outputs = []
+ for critic_ind, critic in enumerate(self.nets["critic"]):
+ critic_loss, critic_output = self._compute_critic_loss(
+ critic=critic,
+ states=s_batch,
+ actions=a_batch,
+ goal_states=goal_s_batch,
+ q_targets=q_targets,
+ )
+ info["critic/critic{}_loss".format(critic_ind + 1)] = critic_loss
+ critic_outputs.append(critic_output)
+
+ if not no_backprop:
+ critic_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["critic"][critic_ind],
+ optim=self.optimizers["critic"][critic_ind],
+ loss=critic_loss,
+ max_grad_norm=self.algo_config.critic.max_gradient_norm,
+ )
+ info["critic/critic{}_grad_norms".format(critic_ind + 1)] = critic_grad_norms
+
+ return info, critic_outputs
+
+ def _train_actor_on_batch(self, batch, action_sampler_outputs, critic_outputs, epoch, no_backprop=False):
+ """
+ A modular helper function that can be overridden in case
+ subclasses would like to modify training behavior for the
+ perturbation actor.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ action_sampler_outputs (dict): dictionary of outputs from the action sampler. Currently
+ unused, although more sophisticated models may use it.
+
+ critic_outputs (dict): dictionary of outputs from the critic. Currently
+ unused, although more sophisticated models may use it.
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ no_backprop (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ assert self.algo_config.actor.enabled
+
+ info = OrderedDict()
+
+ # Actor loss (update with DDPG loss)
+ s_batch = batch["obs"]
+ goal_s_batch = batch["goal_obs"]
+
+ # sample some actions from action sampler and perturb them, then improve perturbations
+ # where improvement is measured by the critic
+ sampled_actions = self.nets["action_sampler"](s_batch, goal_s_batch).detach() # don't backprop into samples
+ perturbed_actions = self.nets["actor"](s_batch, sampled_actions, goal_s_batch)
+ actor_loss = -(self.nets["critic"][0](s_batch, perturbed_actions, goal_s_batch)).mean()
+ info["actor/loss"] = actor_loss
+
+ if not no_backprop:
+ actor_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["actor"],
+ optim=self.optimizers["actor"],
+ loss=actor_loss,
+ )
+ info["actor/grad_norms"] = actor_grad_norms
+
+ return info
+
+ def _get_target_values(self, next_states, goal_states, rewards, dones, action_sampler_outputs=None):
+ """
+ Helper function to get target values for training Q-function with TD-loss.
+
+ Args:
+ next_states (dict): batch of next observations
+ goal_states (dict): if not None, batch of goal observations
+ rewards (torch.Tensor): batch of rewards - should be shape (B, 1)
+ dones (torch.Tensor): batch of done signals - should be shape (B, 1)
+ action_sampler_outputs (dict): dictionary of outputs from the action sampler. Currently
+ unused, although more sophisticated models may use it.
+
+ Returns:
+ q_targets (torch.Tensor): target Q-values to use for TD loss
+ """
+
+ with torch.no_grad():
+ # we need to stack the observations with redundancy @num_action_samples here, then decode
+ # to get all sampled actions. for example, if we generate 2 samples per observation and
+ # the batch size is 3, then ob_tiled = [ob1; ob1; ob2; ob2; ob3; ob3]
+ next_states_tiled = ObsUtils.repeat_and_stack_observation(next_states, n=self.algo_config.critic.num_action_samples)
+ goal_states_tiled = None
+ if len(self.goal_shapes) > 0:
+ goal_states_tiled = ObsUtils.repeat_and_stack_observation(goal_states, n=self.algo_config.critic.num_action_samples)
+
+ # sample action proposals
+ next_sampled_actions = self._sample_actions_for_value_maximization(
+ states_tiled=next_states_tiled,
+ goal_states_tiled=goal_states_tiled,
+ for_target_update=True,
+ )
+
+ q_targets = self._get_target_values_from_sampled_actions(
+ next_states_tiled=next_states_tiled,
+ next_sampled_actions=next_sampled_actions,
+ goal_states_tiled=goal_states_tiled,
+ rewards=rewards,
+ dones=dones,
+ )
+
+ return q_targets
+
+ def _sample_actions_for_value_maximization(self, states_tiled, goal_states_tiled, for_target_update):
+ """
+ Helper function to sample actions for maximization (the "batch-constrained" part of
+ batch-constrained q-learning).
+
+ Args:
+ states_tiled (dict): observations to use for sampling actions. Assumes that tiling
+ has already occurred - so that if the batch size is B, and N samples are
+ desired for each observation in the batch, the leading dimension for each
+ observation in the dict is B * N
+
+ goal_states_tiled (dict): if not None, goal observations
+
+ for_target_update (bool): if True, actions are being sampled for use in training the
+ critic - which means the target actor network should be used
+
+ Returns:
+ sampled_actions (torch.Tensor): actions sampled from the action sampler, and maybe
+ perturbed by the actor network
+ """
+
+ with torch.no_grad():
+ sampled_actions = self.nets["action_sampler"](states_tiled, goal_states_tiled)
+ if self.algo_config.actor.enabled:
+ actor = self.nets["actor"]
+ if for_target_update:
+ actor = self.nets["actor_target"]
+ # perturb the actions with the policy
+ sampled_actions = actor(states_tiled, sampled_actions, goal_states_tiled)
+
+ return sampled_actions
+
+ def _get_target_values_from_sampled_actions(self, next_states_tiled, next_sampled_actions, goal_states_tiled, rewards, dones):
+ """
+ Helper function to get target values for training Q-function with TD-loss. The function
+ assumes that action candidates to maximize over have already been computed, and that
+ the input states have been tiled (repeated) to be compatible with the sampled actions.
+
+ Args:
+ next_states_tiled (dict): next observations to use for sampling actions. Assumes that
+ tiling has already occurred - so that if the batch size is B, and N samples are
+ desired for each observation in the batch, the leading dimension for each
+ observation in the dict is B * N
+
+ next_sampled_actions (torch.Tensor): actions sampled from the action sampler. This function
+ will maximize the critic over these action candidates (using the TD3 trick)
+
+ goal_states_tiled (dict): if not None, goal observations
+
+ rewards (torch.Tensor): batch of rewards - should be shape (B, 1)
+
+ dones (torch.Tensor): batch of done signals - should be shape (B, 1)
+
+ Returns:
+ q_targets (torch.Tensor): target Q-values to use for TD loss
+ """
+ with torch.no_grad():
+ # feed tiled observations and sampled actions into the critics and then
+ # reshape to get all Q-values in second dimension per observation in batch.
+ all_value_targets = self.nets["critic_target"][0](next_states_tiled, next_sampled_actions, goal_states_tiled).reshape(
+ -1, self.algo_config.critic.num_action_samples)
+ max_value_targets = all_value_targets
+ min_value_targets = all_value_targets
+
+ # TD3 trick to combine max and min over all Q-ensemble estimates into single target estimates
+ for critic_target in self.nets["critic_target"][1:]:
+ all_value_targets = critic_target(next_states_tiled, next_sampled_actions, goal_states_tiled).reshape(
+ -1, self.algo_config.critic.num_action_samples)
+ max_value_targets = torch.max(max_value_targets, all_value_targets)
+ min_value_targets = torch.min(min_value_targets, all_value_targets)
+ all_value_targets = self.algo_config.critic.ensemble.weight * min_value_targets + \
+ (1. - self.algo_config.critic.ensemble.weight) * max_value_targets
+
+ # take maximum over all sampled action values per observation and compute targets
+ value_targets = torch.max(all_value_targets, dim=1, keepdim=True)[0]
+ q_targets = rewards + dones * self.discount * value_targets
+
+ return q_targets
+
+ def _compute_critic_loss(self, critic, states, actions, goal_states, q_targets):
+ """
+ Helper function to compute loss between estimated Q-values and target Q-values.
+ It should also return outputs needed for downstream training (for training the
+ actor).
+
+ Args:
+ critic (torch.nn.Module): critic network
+ states (dict): batch of observations
+ actions (torch.Tensor): batch of actions
+ goal_states (dict): if not None, batch of goal observations
+ q_targets (torch.Tensor): batch of target q-values for the TD loss
+
+ Returns:
+ critic_loss (torch.Tensor): critic loss
+ critic_output (dict): additional outputs from the critic. This function
+ returns None, but subclasses may want to provide some information
+ here.
+ """
+ q_estimated = critic(states, actions, goal_states)
+ if self.algo_config.critic.use_huber:
+ critic_loss = nn.SmoothL1Loss()(q_estimated, q_targets)
+ else:
+ critic_loss = nn.MSELoss()(q_estimated, q_targets)
+ return critic_loss, None
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ info = PolicyAlgo.train_on_batch(self, batch, epoch, validate=validate)
+
+ # Action Sampler training
+ no_action_sampler_backprop = validate or (not self._check_epoch(net_name="action_sampler", epoch=epoch))
+ with TorchUtils.maybe_no_grad(no_grad=no_action_sampler_backprop):
+ action_sampler_info, action_sampler_outputs = self._train_action_sampler_on_batch(
+ batch=batch,
+ epoch=epoch,
+ no_backprop=no_action_sampler_backprop,
+ )
+ info.update(action_sampler_info)
+
+ # make sure action sampler is in eval mode for models like GMM which may require low-noise
+ # samples when sampling actions.
+ self.nets["action_sampler"].eval()
+
+ # Critic training
+ no_critic_backprop = validate or (not self._check_epoch(net_name="critic", epoch=epoch))
+ with TorchUtils.maybe_no_grad(no_grad=no_critic_backprop):
+ critic_info, critic_outputs = self._train_critic_on_batch(
+ batch=batch,
+ action_sampler_outputs=action_sampler_outputs,
+ epoch=epoch,
+ no_backprop=no_critic_backprop,
+ )
+ info.update(critic_info)
+
+ if self.algo_config.actor.enabled:
+ # Actor training
+ no_actor_backprop = validate or (not self._check_epoch(net_name="actor", epoch=epoch))
+ with TorchUtils.maybe_no_grad(no_grad=no_actor_backprop):
+ actor_info = self._train_actor_on_batch(
+ batch=batch,
+ action_sampler_outputs=action_sampler_outputs,
+ critic_outputs=critic_outputs,
+ epoch=epoch,
+ no_backprop=no_actor_backprop,
+ )
+ info.update(actor_info)
+
+ if not validate:
+ # restore to train mode if necessary
+ self.nets["action_sampler"].train()
+
+ # update the target critic networks (only when critic has gradient update)
+ if not no_critic_backprop:
+ with torch.no_grad():
+ for critic_ind in range(len(self.nets["critic"])):
+ TorchUtils.soft_update(
+ source=self.nets["critic"][critic_ind],
+ target=self.nets["critic_target"][critic_ind],
+ tau=self.algo_config.target_tau,
+ )
+
+ # update target actor network (only when actor has gradient update)
+ if self.algo_config.actor.enabled and (not no_actor_backprop):
+ with torch.no_grad():
+ TorchUtils.soft_update(
+ source=self.nets["actor"],
+ target=self.nets["actor_target"],
+ tau=self.algo_config.target_tau,
+ )
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ loss_log = OrderedDict()
+
+ # record current optimizer learning rates
+ for k in self.optimizers:
+ keys = [k]
+ optims = [self.optimizers[k]]
+ if k == "critic":
+ # account for critic having one optimizer per ensemble member
+ keys = ["{}{}".format(k, critic_ind) for critic_ind in range(len(self.nets["critic"]))]
+ optims = self.optimizers[k]
+ for kp, optimizer in zip(keys, optims):
+ for i, param_group in enumerate(optimizer.param_groups):
+ loss_log["Optimizer/{}{}_lr".format(kp, i)] = param_group["lr"]
+
+ # extract relevant logs for action sampler, critic, and actor
+ loss_log["Loss"] = 0.
+ for loss_logger in [self._log_action_sampler_info, self._log_critic_info, self._log_actor_info]:
+ this_log = loss_logger(info)
+ if "Loss" in this_log:
+ # manually merge total loss
+ loss_log["Loss"] += this_log["Loss"]
+ del this_log["Loss"]
+ loss_log.update(this_log)
+
+ return loss_log
+
+ def _log_action_sampler_info(self, info):
+ """
+ Helper function to extract action sampler-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ loss_log["Action_Sampler/Loss"] = info["action_sampler/loss"].item()
+ loss_log["Action_Sampler/Reconsruction_Loss"] = info["action_sampler/recons_loss"].item()
+ loss_log["Action_Sampler/KL_Loss"] = info["action_sampler/kl_loss"].item()
+ if self.algo_config.action_sampler.vae.prior.use_categorical:
+ loss_log["Action_Sampler/Gumbel_Temperature"] = self.nets["action_sampler"].get_gumbel_temperature()
+ else:
+ loss_log["Action_Sampler/Encoder_Variance"] = info["action_sampler/encoder_variance"].item()
+ if "action_sampler/grad_norms" in info:
+ loss_log["Action_Sampler/Grad_Norms"] = info["action_sampler/grad_norms"]
+ loss_log["Loss"] = loss_log["Action_Sampler/Loss"]
+ return loss_log
+
+ def _log_critic_info(self, info):
+ """
+ Helper function to extract critic-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ if "done_masks" in info:
+ loss_log["Critic/Done_Mask_Percentage"] = 100. * torch.mean(info["done_masks"]).item()
+ if "critic/q_targets" in info:
+ loss_log["Critic/Q_Targets"] = info["critic/q_targets"].mean().item()
+ loss_log["Loss"] = 0.
+ for critic_ind in range(len(self.nets["critic"])):
+ loss_log["Critic/Critic{}_Loss".format(critic_ind + 1)] = info["critic/critic{}_loss".format(critic_ind + 1)].item()
+ if "critic/critic{}_grad_norms".format(critic_ind + 1) in info:
+ loss_log["Critic/Critic{}_Grad_Norms".format(critic_ind + 1)] = info["critic/critic{}_grad_norms".format(critic_ind + 1)]
+ loss_log["Loss"] += loss_log["Critic/Critic{}_Loss".format(critic_ind + 1)]
+ return loss_log
+
+ def _log_actor_info(self, info):
+ """
+ Helper function to extract actor-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ if self.algo_config.actor.enabled:
+ loss_log["Actor/Loss"] = info["actor/loss"].item()
+ if "actor/grad_norms" in info:
+ loss_log["Actor/Grad_Norms"] = info["actor/grad_norms"]
+ loss_log["Loss"] = loss_log["Actor/Loss"]
+ return loss_log
+
+ def set_train(self):
+ """
+ Prepare networks for evaluation. Update from super class to make sure
+ target networks stay in evaluation mode all the time.
+ """
+ self.nets.train()
+
+ # target networks always in eval
+ for critic_ind in range(len(self.nets["critic_target"])):
+ self.nets["critic_target"][critic_ind].eval()
+
+ if self.algo_config.actor.enabled:
+ self.nets["actor_target"].eval()
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+
+ # LR scheduling updates
+ for lr_sc in self.lr_schedulers["critic"]:
+ if lr_sc is not None:
+ lr_sc.step()
+
+ if self.lr_schedulers["action_sampler"] is not None:
+ self.lr_schedulers["action_sampler"].step()
+
+ if self.algo_config.actor.enabled and self.lr_schedulers["actor"] is not None:
+ self.lr_schedulers["actor"].step()
+
+ def _get_best_value(self, obs_dict, goal_dict=None):
+ """
+ Internal helper function for getting the best value for a given state and
+ the corresponding best action. Meant to be used at test-time. Key differences
+ between this and retrieving target values at train-time are that (1) only a
+ single critic is used for the value estimate and (2) the critic and actor
+ are used instead of the target critic and target actor.
+
+ Args:
+ obs_dict (dict): batch of current observations
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ best_value (torch.Tensor): best values
+ best_action (torch.Tensor): best actions
+ """
+ assert not self.nets.training
+
+ random_key = list(obs_dict.keys())[0]
+ batch_size = obs_dict[random_key].shape[0]
+
+ # number of action proposals from action sampler
+ num_action_samples = self.algo_config.critic.num_action_samples_rollout
+
+ # we need to stack the observations with redundancy @num_action_samples here, then decode
+ # to get all sampled actions. for example, if we generate 2 samples per observation and
+ # the batch size is 3, then ob_tiled = [ob1; ob1; ob2; ob2; ob3; ob3]
+ ob_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_action_samples)
+ goal_tiled = None
+ if len(self.goal_shapes) > 0:
+ goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_action_samples)
+
+ sampled_actions = self._sample_actions_for_value_maximization(
+ states_tiled=ob_tiled,
+ goal_states_tiled=goal_tiled,
+ for_target_update=False,
+ )
+
+ # feed tiled observations and perturbed sampled actions into the critic and then
+ # reshape to get all Q-values in second dimension per observation in batch.
+ # finally, just take a maximum across that second dimension to take the best sampled action
+ all_critic_values = self.nets["critic"][0](ob_tiled, sampled_actions, goal_tiled).reshape(-1, num_action_samples)
+ best_action_index = torch.argmax(all_critic_values, dim=1)
+
+ all_actions = sampled_actions.reshape(batch_size, num_action_samples, -1)
+ best_action = all_actions[torch.arange(all_actions.shape[0]), best_action_index]
+ best_value = all_critic_values[torch.arange(all_critic_values.shape[0]), best_action_index].unsqueeze(1)
+
+ return best_value, best_action
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+
+ _, best_action = self._get_best_value(obs_dict=obs_dict, goal_dict=goal_dict)
+ return best_action
+
+ def get_state_value(self, obs_dict, goal_dict=None):
+ """
+ Get state value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ assert not self.nets.training
+
+ best_value, _ = self._get_best_value(obs_dict=obs_dict, goal_dict=goal_dict)
+ return best_value
+
+ def get_state_action_value(self, obs_dict, actions, goal_dict=None):
+ """
+ Get state-action value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ actions (torch.Tensor): action
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["critic"][0](obs_dict, actions, goal_dict)
+
+
+class BCQ_GMM(BCQ):
+ """
+ A simple modification to BCQ that replaces the VAE used to sample action proposals from the
+ batch with a GMM.
+ """
+ def _create_action_sampler(self):
+ """
+ Called in @_create_networks to make action sampler network.
+ """
+ assert self.algo_config.action_sampler.gmm.enabled
+
+ # GMM network for approximate sampling from batch dataset
+ self.nets["action_sampler"] = PolicyNets.GMMActorNetwork(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.action_sampler.actor_layer_dims,
+ num_modes=self.algo_config.action_sampler.gmm.num_modes,
+ min_std=self.algo_config.action_sampler.gmm.min_std,
+ std_activation=self.algo_config.action_sampler.gmm.std_activation,
+ low_noise_eval=self.algo_config.action_sampler.gmm.low_noise_eval,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ def _train_action_sampler_on_batch(self, batch, epoch, no_backprop=False):
+ """
+ Modify this helper function from superclass to train GMM action sampler
+ with maximum likelihood.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ no_backprop (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ outputs (dict): dictionary of outputs to use during critic training
+ (for computing target values)
+ """
+ info = OrderedDict()
+
+ # GMM forward
+ dists = self.nets["action_sampler"].forward_train(
+ obs_dict=batch["obs"],
+ goal_dict=batch["goal_obs"],
+ )
+
+ # make sure that this is a batch of multivariate action distributions, so that
+ # the log probability computation will be correct
+ assert len(dists.batch_shape) == 1
+ log_probs = dists.log_prob(batch["actions"])
+ loss = -log_probs.mean()
+ info["action_sampler/loss"] = loss
+
+ # GMM gradient step
+ if not no_backprop:
+ gmm_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["action_sampler"],
+ optim=self.optimizers["action_sampler"],
+ loss=loss,
+ )
+ info["action_sampler/grad_norms"] = gmm_grad_norms
+ return info, None
+
+ def _log_action_sampler_info(self, info):
+ """
+ Update from superclass for GMM (no KL loss).
+ """
+ loss_log = OrderedDict()
+ loss_log["Action_Sampler/Loss"] = info["action_sampler/loss"].item()
+ if "action_sampler/grad_norms" in info:
+ loss_log["Action_Sampler/Grad_Norms"] = info["action_sampler/grad_norms"]
+ loss_log["Loss"] = loss_log["Action_Sampler/Loss"]
+ return loss_log
+
+
+class BCQ_Distributional(BCQ):
+ """
+ BCQ with distributional critics. Distributional critics output categorical
+ distributions over a discrete set of values instead of expected returns.
+ Some parts of this implementation were adapted from ACME (https://github.com/deepmind/acme).
+ """
+ def _create_critics(self):
+ """
+ Called in @_create_networks to make critic networks.
+ """
+ assert self.algo_config.critic.distributional.enabled
+ critic_class = ValueNets.DistributionalActionValueNetwork
+ critic_args = dict(
+ obs_shapes=self.obs_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.critic.layer_dims,
+ value_bounds=self.algo_config.critic.value_bounds,
+ num_atoms=self.algo_config.critic.distributional.num_atoms,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ # Q network ensemble and target ensemble
+ self.nets["critic"] = nn.ModuleList()
+ self.nets["critic_target"] = nn.ModuleList()
+
+ # NOTE: ensemble value in config is ignored, and only 1 critic is used.
+ critic = critic_class(**critic_args)
+ self.nets["critic"].append(critic)
+
+ critic_target = critic_class(**critic_args)
+ self.nets["critic_target"].append(critic_target)
+
+ def _get_target_values_from_sampled_actions(self, next_states_tiled, next_sampled_actions, goal_states_tiled, rewards, dones):
+ """
+ Helper function to get target values for training Q-function with TD-loss. Update from superclass
+ to account for distributional value functions.
+
+ Args:
+ next_states_tiled (dict): next observations to use for sampling actions. Assumes that
+ tiling has already occurred - so that if the batch size is B, and N samples are
+ desired for each observation in the batch, the leading dimension for each
+ observation in the dict is B * N
+
+ next_sampled_actions (torch.Tensor): actions sampled from the action sampler. This function
+ will maximize the critic over these action candidates (using the TD3 trick)
+
+ goal_states_tiled (dict): if not None, goal observations
+
+ rewards (torch.Tensor): batch of rewards - should be shape (B, 1)
+
+ dones (torch.Tensor): batch of done signals - should be shape (B, 1)
+
+ Returns:
+ target_categorical_probabilities (torch.Tensor): target categorical probabilities
+ to use in the bellman backup
+ """
+
+ with torch.no_grad():
+ # compute expected returns of the sampled actions and maximize to find the best action
+ all_vds = self.nets["critic_target"][0].forward_train(next_states_tiled, next_sampled_actions, goal_states_tiled)
+ expected_values = all_vds.mean().reshape(-1, self.algo_config.critic.num_action_samples)
+ best_action_index = torch.argmax(expected_values, dim=1)
+ all_actions = next_sampled_actions.reshape(-1, self.algo_config.critic.num_action_samples, self.ac_dim)
+ best_action = all_actions[torch.arange(all_actions.shape[0]), best_action_index]
+
+ # get the corresponding probabilities for the categorical distributions corresponding to the best actions
+ all_vd_probs = all_vds.probs.reshape(-1, self.algo_config.critic.num_action_samples, self.algo_config.critic.distributional.num_atoms)
+ target_vd_probs = all_vd_probs[torch.arange(all_vd_probs.shape[0]), best_action_index]
+
+ # bellman backup to get a new grid of values - then project onto the canonical atoms to obtain a
+ # target set of categorical probabilities over the atoms
+ atom_value_grid = all_vds.values
+ target_value_grid = rewards + dones * self.discount * atom_value_grid
+ target_categorical_probabilities = LossUtils.project_values_onto_atoms(
+ values=target_value_grid,
+ probabilities=target_vd_probs,
+ atoms=atom_value_grid,
+ )
+
+ return target_categorical_probabilities
+
+ def _compute_critic_loss(self, critic, states, actions, goal_states, q_targets):
+ """
+ Overrides super class to compute a distributional loss. Since values are
+ categorical distributions, this is just computing a cross-entropy
+ loss between the two distributions.
+
+ NOTE: q_targets is expected to be a batch of normalized probability vectors that correspond to
+ the target categorical distributions over the value atoms.
+
+ Args:
+ critic (torch.nn.Module): critic network
+ states (dict): batch of observations
+ actions (torch.Tensor): batch of actions
+ goal_states (dict): if not None, batch of goal observations
+ q_targets (torch.Tensor): batch of target q-values for the TD loss
+
+ Returns:
+ critic_loss (torch.Tensor): critic loss
+ critic_output (dict): additional outputs from the critic. This function
+ returns None, but subclasses may want to provide some information
+ here.
+ """
+
+ # this should be the equivalent of softmax with logits from tf
+ vd = critic.forward_train(states, actions, goal_states)
+ log_probs = F.log_softmax(vd.logits, dim=-1)
+ critic_loss = nn.KLDivLoss(reduction='batchmean')(log_probs, q_targets)
+ return critic_loss, None
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/cql.py b/phantom/submodules/phantom-robomimic/robomimic/algo/cql.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c24d50abd91426a4d96e91896c958b7df1ada0a
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/cql.py
@@ -0,0 +1,668 @@
+"""
+Implementation of Conservative Q-Learning (CQL).
+Based off of https://github.com/aviralkumar2907/CQL.
+(Paper - https://arxiv.org/abs/2006.04779).
+"""
+import numpy as np
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+import robomimic.models.base_nets as BaseNets
+import robomimic.models.obs_nets as ObsNets
+import robomimic.models.policy_nets as PolicyNets
+import robomimic.models.value_nets as ValueNets
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.algo import register_algo_factory_func, ValueAlgo, PolicyAlgo
+
+
+@register_algo_factory_func("cql")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the CQL algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ return CQL, {}
+
+
+class CQL(PolicyAlgo, ValueAlgo):
+ """
+ CQL-extension of SAC for the off-policy, offline setting. See https://arxiv.org/abs/2006.04779
+ """
+ def __init__(self, **kwargs):
+ # Store entropy / cql settings first since the super init call requires them
+ self.automatic_entropy_tuning = kwargs["algo_config"].actor.target_entropy is not None
+ self.automatic_cql_tuning = kwargs["algo_config"].critic.target_q_gap is not None and \
+ kwargs["algo_config"].critic.target_q_gap >= 0.0
+
+ # Run super init first
+ super().__init__(**kwargs)
+
+ # Reward settings
+ self.n_step = self.algo_config.n_step
+ self.discount = self.algo_config.discount ** self.n_step
+
+ # Now also store additional SAC- and CQL-specific stuff from the config
+ self._num_batch_steps = 0
+ self.bc_start_steps = self.algo_config.actor.bc_start_steps
+ self.deterministic_backup = self.algo_config.critic.deterministic_backup
+ self.td_loss_fcn = nn.SmoothL1Loss() if self.algo_config.critic.use_huber else nn.MSELoss()
+
+ # Entropy settings
+ self.target_entropy = -np.prod(self.ac_dim) if self.algo_config.actor.target_entropy in {None, "default"} else\
+ self.algo_config.actor.target_entropy
+
+ # CQL settings
+ self.min_q_weight = self.algo_config.critic.min_q_weight
+ self.target_q_gap = self.algo_config.critic.target_q_gap if self.automatic_cql_tuning else 0.0
+
+ @property
+ def log_entropy_weight(self):
+ return self.nets["log_entropy_weight"]() if self.automatic_entropy_tuning else\
+ torch.zeros(1, requires_grad=False, device=self.device)
+
+ @property
+ def log_cql_weight(self):
+ return self.nets["log_cql_weight"]() if self.automatic_cql_tuning else\
+ torch.log(torch.tensor(self.algo_config.critic.cql_weight, requires_grad=False, device=self.device))
+
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+
+ Networks for this algo: critic (potentially ensemble), policy
+ """
+
+ # Create nets
+ self.nets = nn.ModuleDict()
+
+ # Assemble args to pass to actor
+ actor_args = dict(self.algo_config.actor.net.common)
+
+ # Add network-specific args and define network class
+ if self.algo_config.actor.net.type == "gaussian":
+ actor_cls = PolicyNets.GaussianActorNetwork
+ actor_args.update(dict(self.algo_config.actor.net.gaussian))
+ else:
+ # Unsupported actor type!
+ raise ValueError(f"Unsupported actor requested. "
+ f"Requested: {self.algo_config.actor.net.type}, "
+ f"valid options are: {['gaussian']}")
+
+ # Policy
+ self.nets["actor"] = actor_cls(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor.layer_dims,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **actor_args,
+ )
+
+ # Critics
+ self.nets["critic"] = nn.ModuleList()
+ self.nets["critic_target"] = nn.ModuleList()
+ for _ in range(self.algo_config.critic.ensemble.n):
+ for net_list in (self.nets["critic"], self.nets["critic_target"]):
+ critic = ValueNets.ActionValueNetwork(
+ obs_shapes=self.obs_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.critic.layer_dims,
+ value_bounds=self.algo_config.critic.value_bounds,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+ net_list.append(critic)
+
+ # Entropy (if automatically tuning)
+ if self.automatic_entropy_tuning:
+ self.nets["log_entropy_weight"] = BaseNets.Parameter(torch.zeros(1))
+
+ # CQL (if automatically tuning)
+ if self.automatic_cql_tuning:
+ self.nets["log_cql_weight"] = BaseNets.Parameter(torch.zeros(1))
+
+ # Send networks to appropriate device
+ self.nets = self.nets.float().to(self.device)
+
+ # sync target networks at beginning of training
+ with torch.no_grad():
+ for critic, critic_target in zip(self.nets["critic"], self.nets["critic_target"]):
+ TorchUtils.hard_update(
+ source=critic,
+ target=critic_target,
+ )
+
+ def _create_optimizers(self):
+ """
+ Creates optimizers using @self.optim_params and places them into @self.optimizers.
+
+ Overrides base method since we might need to create aditional optimizers for the entropy
+ and cql weight parameters (by default, the base class only creates optimizers for all
+ entries in @self.nets that have corresponding entries in `self.optim_params` but these
+ parameters do not).
+ """
+
+ # Create actor and critic optimizers via super method
+ super()._create_optimizers()
+
+ # We still need to potentially create additional optimizers based on algo settings
+
+ # entropy (if automatically tuning)
+ if self.automatic_entropy_tuning:
+ self.optimizers["entropy"] = optim.Adam(
+ params=self.nets["log_entropy_weight"].parameters(),
+ lr=self.optim_params["actor"]["learning_rate"]["initial"],
+ weight_decay=0.0,
+ )
+
+ # cql (if automatically tuning)
+ if self.automatic_cql_tuning:
+ self.optimizers["cql"] = optim.Adam(
+ params=self.nets["log_cql_weight"].parameters(),
+ lr=self.optim_params["critic"]["learning_rate"]["initial"],
+ weight_decay=0.0,
+ )
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out relevant info and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ # Make sure the trajectory of actions received is greater than our step horizon
+ assert batch["actions"].shape[1] >= self.n_step
+
+ # remove temporal batches for all
+ input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
+ input_batch["next_obs"] = {k: batch["next_obs"][k][:, self.n_step - 1, :] for k in batch["next_obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"][:, 0, :]
+
+ # note: ensure scalar signals (rewards, done) retain last dimension of 1 to be compatible with model outputs
+
+ # single timestep reward is discounted sum of intermediate rewards in sequence
+ reward_seq = batch["rewards"][:, :self.n_step]
+ discounts = torch.pow(self.algo_config.discount, torch.arange(self.n_step).float()).unsqueeze(0)
+ input_batch["rewards"] = (reward_seq * discounts).sum(dim=1).unsqueeze(1)
+
+ # consider this n-step seqeunce done if any intermediate dones are present
+ done_seq = batch["dones"][:, :self.n_step]
+ input_batch["dones"] = (done_seq.sum(dim=1) > 0).float().unsqueeze(1)
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = OrderedDict()
+
+ # Set the correct context for this training step
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ # Always run super call first
+ super_info = super().train_on_batch(batch, epoch, validate=validate)
+ # Train actor
+ actor_info = self._train_policy_on_batch(batch, epoch, validate)
+ # Train critic(s)
+ critic_info = self._train_critic_on_batch(batch, epoch, validate)
+ # Update info
+ info.update(super_info)
+ info.update(actor_info)
+ info.update(critic_info)
+
+ # Return stats
+ return info
+
+ def _train_policy_on_batch(self, batch, epoch, validate=False):
+ """
+ Training policy on a single batch of data.
+
+ Loss is the ExpValue over sampled states of the (weighted) logprob of a sampled action
+ under the current policy minus the Q value of associated with the (s, a) combo
+
+ Intuitively, this tries to improve the odds of sampling actions with high Q values while simultaneously
+ penalizing high probability actions.
+
+ Since we're in the continuous setting, we monte carlo sample.
+
+ Concretely:
+ Loss = Average[ entropy_weight * logprob(f(eps; s) | s) - Q(s, f(eps; s) ]
+
+ where we use the reparameterization trick with Gaussian function f(*) to parameterize
+ actions as a function of the sampled noise param eps given input state s
+
+ Additionally, we update the (log) entropy weight parameter if we're tuning that as well.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = OrderedDict()
+
+ # Sample actions from policy and get log probs
+ dist = self.nets["actor"].forward_train(obs_dict=batch["obs"], goal_dict=batch["goal_obs"])
+ actions, log_prob = self._get_actions_and_log_prob(dist=dist)
+
+ # Calculate alpha
+ entropy_weight_loss = -(self.log_entropy_weight * (log_prob + self.target_entropy).detach()).mean() if\
+ self.automatic_entropy_tuning else 0.0
+ entropy_weight = self.log_entropy_weight.exp()
+
+ # Get predicted Q-values for all state, action pairs
+ pred_qs = [critic(obs_dict=batch["obs"], acts=actions, goal_dict=batch["goal_obs"])
+ for critic in self.nets["critic"]]
+ # We take the minimum for stability
+ pred_qs, _ = torch.cat(pred_qs, dim=1).min(dim=1, keepdim=True)
+
+ # Use BC if we're in the beginning of training, otherwise calculate policy loss normally
+ baseline = dist.log_prob(batch["actions"]).unsqueeze(dim=-1) if\
+ self._num_batch_steps < self.bc_start_steps else pred_qs
+ policy_loss = (entropy_weight * log_prob - baseline).mean()
+
+ # Add info
+ info["entropy_weight"] = entropy_weight.item()
+ info["entropy_weight_loss"] = entropy_weight_loss.item() if \
+ self.automatic_entropy_tuning else entropy_weight_loss
+ info["actor/loss"] = policy_loss
+
+ # Take a training step if we're not validating
+ if not validate:
+ # Update batch step
+ self._num_batch_steps += 1
+ if self.automatic_entropy_tuning:
+ # Alpha
+ self.optimizers["entropy"].zero_grad()
+ entropy_weight_loss.backward()
+ self.optimizers["entropy"].step()
+ info["entropy_grad_norms"] = self.log_entropy_weight.grad.data.norm(2).pow(2).item()
+
+ # Policy
+ actor_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["actor"],
+ optim=self.optimizers["actor"],
+ loss=policy_loss,
+ max_grad_norm=self.algo_config.actor.max_gradient_norm,
+ )
+ # Add info
+ info["actor/grad_norms"] = actor_grad_norms
+
+ # Return stats
+ return info
+
+ def _train_critic_on_batch(self, batch, epoch, validate=False):
+ """
+ Training critic(s) on a single batch of data.
+
+ For a given batch of (s, a, r, s') tuples and n sampled actions (a_, a'_ corresponding to actions
+ sampled from the learned policy at states s and s', respectively; a~ corresponding to uniformly random
+ sampled actions):
+
+ Loss = CQL_loss + SAC_loss
+
+ Since we're in the continuous setting, we monte carlo sample for all ExpValues, which become Averages instead
+
+ SAC_loss is the standard single-step TD error, corresponding to the following:
+
+ SAC_loss = 0.5 * Average[ (Q(s,a) - (r + Average over a'_ [ Q(s', a'_) ]))^2 ]
+
+ The CQL_loss corresponds to a weighted secondary objective, corresponding to the (ExpValue of Q values over
+ sampled states and sampled actions from the LEARNED policy) minus the (ExpValue of Q values over
+ sampled states and sampled actions from the DATASET policy) plus a regularizer as a function
+ of the learned policy.
+
+ Intuitively, this tries to penalize Q-values arbitrarily resulting from the learned policy (which may produce
+ out-of-distribution (s,a) pairs) while preserving (known) Q-values taken from the dataset policy.
+
+ As we are using SAC, we choose our regularizer to correspond to the negative KL divergence between our
+ learned policy and a uniform distribution such that the first term in the CQL loss corresponds to the
+ soft maximum over all Q values at any state s.
+
+ For stability, we importance sample actions over random actions and from the current policy at s, s'.
+
+ Moreover, if we want to tune the cql_weight automatically, we include the threshold value target_q_gap
+ to penalize Q values that are overly-optimistic by the given threshold.
+
+ In this case, the CQL_loss is as follows:
+
+ CQL_loss = cql_weight * (Average [log (Average over a` in {a~, a_, a_'}: exp(Q(s,a`) - logprob(a`)) - Average [Q(s,a)]] - target_q_gap)
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = OrderedDict()
+ B, A = batch["actions"].shape
+ N = self.algo_config.critic.num_random_actions
+
+ # Get predicted Q-values from taken actions
+ q_preds = [critic(obs_dict=batch["obs"], acts=batch["actions"], goal_dict=batch["goal_obs"])
+ for critic in self.nets["critic"]]
+
+ # Sample actions at the current and next step
+ curr_dist = self.nets["actor"].forward_train(obs_dict=batch["obs"], goal_dict=batch["goal_obs"])
+ next_dist = self.nets["actor"].forward_train(obs_dict=batch["next_obs"], goal_dict=batch["goal_obs"])
+ next_actions, next_log_prob = self._get_actions_and_log_prob(dist=next_dist)
+
+ # Don't capture gradients here, since the critic target network doesn't get trained (only soft updated)
+ with torch.no_grad():
+ # We take the max over all samples if the number of action samples is > 1
+ if self.algo_config.critic.num_action_samples > 1:
+ # Generate the target q values, using the backup from the next state
+ temp_actions = next_dist.rsample(sample_shape=(self.algo_config.critic.num_action_samples,)).permute(1, 0, 2)
+ target_qs = [self._get_qs_from_actions(
+ obs_dict=batch["next_obs"], actions=temp_actions, goal_dict=batch["goal_obs"], q_net=critic)
+ .max(dim=1, keepdim=True)[0] for critic in self.nets["critic_target"]]
+ else:
+ target_qs = [critic(obs_dict=batch["next_obs"], acts=next_actions, goal_dict=batch["goal_obs"])
+ for critic in self.nets["critic_target"]]
+ # Take the minimum over all critics
+ target_qs, _ = torch.cat(target_qs, dim=1).min(dim=1, keepdim=True)
+ # If only sampled once from each critic and not using a deterministic backup, subtract the logprob as well
+ if self.algo_config.critic.num_action_samples == 1 and not self.deterministic_backup:
+ target_qs = target_qs - self.log_entropy_weight.exp() * next_log_prob
+
+ # Calculate the q target values
+ done_mask_batch = 1. - batch["dones"]
+ info["done_masks"] = done_mask_batch
+ q_target = batch["rewards"] + done_mask_batch * self.discount * target_qs
+
+ # Calculate CQL stuff
+ cql_random_actions = torch.FloatTensor(N, B, A).uniform_(-1., 1.).to(self.device) # shape (N, B, A)
+ cql_random_log_prob = np.log(0.5 ** A)
+ cql_curr_actions, cql_curr_log_prob = self._get_actions_and_log_prob(dist=curr_dist, sample_shape=(N,)) # shape (N, B, A) and (N, B, 1)
+ cql_next_actions, cql_next_log_prob = self._get_actions_and_log_prob(dist=next_dist, sample_shape=(N,)) # shape (N, B, A) and (N, B, 1)
+ cql_curr_log_prob = cql_curr_log_prob.squeeze(dim=-1).permute(1, 0).detach() # shape (B, N)
+ cql_next_log_prob = cql_next_log_prob.squeeze(dim=-1).permute(1, 0).detach() # shape (B, N)
+ q_cats = [] # Each entry shape will be (B, N)
+
+ for critic, q_pred in zip(self.nets["critic"], q_preds):
+ # Compose Q values over all sampled actions (importance sampled)
+ q_rand = self._get_qs_from_actions(obs_dict=batch["obs"], actions=cql_random_actions.permute(1, 0, 2), goal_dict=batch["goal_obs"], q_net=critic)
+ q_curr = self._get_qs_from_actions(obs_dict=batch["obs"], actions=cql_curr_actions.permute(1, 0, 2), goal_dict=batch["goal_obs"], q_net=critic)
+ q_next = self._get_qs_from_actions(obs_dict=batch["obs"], actions=cql_next_actions.permute(1, 0, 2), goal_dict=batch["goal_obs"], q_net=critic)
+ q_cat = torch.cat([
+ q_rand - cql_random_log_prob,
+ q_next - cql_next_log_prob,
+ q_curr - cql_curr_log_prob,
+ ], dim=1) # shape (B, 3 * N)
+ q_cats.append(q_cat)
+
+ # Calculate the losses for all critics
+ cql_losses = []
+ critic_losses = []
+ cql_weight = torch.clamp(self.log_cql_weight.exp(), min=0.0, max=1000000.0)
+ info["critic/cql_weight"] = cql_weight.item()
+ for i, (q_pred, q_cat) in enumerate(zip(q_preds, q_cats)):
+ # Calculate td error loss
+ td_loss = self.td_loss_fcn(q_pred, q_target)
+ # Calculate cql loss
+ cql_loss = cql_weight * (self.min_q_weight * (torch.logsumexp(q_cat, dim=1).mean() - q_pred.mean()) -
+ self.target_q_gap)
+ cql_losses.append(cql_loss)
+ # Calculate total loss
+ loss = td_loss + cql_loss
+ critic_losses.append(loss)
+ info[f"critic/critic{i+1}_loss"] = loss
+
+ # Run gradient descent if we're not validating
+ if not validate:
+ # Train CQL weight if tuning automatically
+ if self.automatic_cql_tuning:
+ cql_weight_loss = -torch.stack(cql_losses).mean()
+ info[
+ "critic/cql_weight_loss"] = cql_weight_loss.item() # Make sure to not store computation graph since we retain graph after backward() call
+ self.optimizers["cql"].zero_grad()
+ cql_weight_loss.backward(retain_graph=True)
+ self.optimizers["cql"].step()
+ info["critic/cql_grad_norms"] = self.log_cql_weight.grad.data.norm(2).pow(2).item()
+
+ # Train critics
+ for i, (critic_loss, critic, critic_target, optimizer) in enumerate(zip(
+ critic_losses, self.nets["critic"], self.nets["critic_target"], self.optimizers["critic"]
+ )):
+ retain_graph = (i < (len(critic_losses) - 1))
+ critic_grad_norms = TorchUtils.backprop_for_loss(
+ net=critic,
+ optim=optimizer,
+ loss=critic_loss,
+ max_grad_norm=self.algo_config.critic.max_gradient_norm,
+ retain_graph=retain_graph,
+ )
+ info[f"critic/critic{i+1}_grad_norms"] = critic_grad_norms
+ with torch.no_grad():
+ TorchUtils.soft_update(source=critic, target=critic_target, tau=self.algo_config.target_tau)
+
+ # Return stats
+ return info
+
+ def _get_actions_and_log_prob(self, dist, sample_shape=torch.Size()):
+ """
+ Helper method to sample actions and compute corresponding log probabilities
+
+ Args:
+ dist (Distribution): Distribution to sample from
+ sample_shape (torch.Size or tuple): Shape of output when sampling (number of samples)
+
+ Returns:
+ 2-tuple:
+ - (tensor) sampled actions (..., B, ..., A)
+ - (tensor) corresponding log probabilities (..., B, ..., 1)
+ """
+ # Process networks with tanh differently than normal distributions
+ if self.algo_config.actor.net.common.use_tanh:
+ actions, actions_pre_tanh = dist.rsample(sample_shape=sample_shape, return_pretanh_value=True)
+ log_prob = dist.log_prob(actions, pre_tanh_value=actions_pre_tanh).unsqueeze(dim=-1)
+ else:
+ actions = dist.rsample(sample_shape=sample_shape)
+ log_prob = dist.log_prob(actions)
+
+ return actions, log_prob
+
+ @staticmethod
+ def _get_qs_from_actions(obs_dict, actions, goal_dict, q_net):
+ """
+ Helper function for grabbing Q values given a single state and multiple (N) sampled actions.
+
+ Args:
+ obs_dict (dict): Observation dict from batch
+ actions (tensor): Torch tensor, with dim1 assumed to be the extra sampled dimension
+ goal_dict (dict): Goal dict from batch
+ q_net (nn.Module): Q net to pass the observations and actions
+
+ Returns:
+ tensor: (B, N) corresponding Q values
+ """
+ # Get the number of sampled actions
+ B, N, D = actions.shape
+
+ # Repeat obs and goals in the batch dimension
+ obs_dict_stacked = ObsUtils.repeat_and_stack_observation(obs_dict, N)
+ goal_dict_stacked = ObsUtils.repeat_and_stack_observation(goal_dict, N)
+
+ # Pass the obs and (flattened) actions through to get the Q values
+ qs = q_net(obs_dict=obs_dict_stacked, acts=actions.reshape(-1, D), goal_dict=goal_dict_stacked)
+
+ # Unflatten output
+ qs = qs.reshape(B, N)
+
+ return qs
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ loss_log = OrderedDict()
+
+ # record current optimizer learning rates
+ for k in self.optimizers:
+ keys = [k]
+ optims = [self.optimizers[k]]
+ if k == "critic":
+ # account for critic having one optimizer per ensemble member
+ keys = ["{}{}".format(k, critic_ind) for critic_ind in range(len(self.nets["critic"]))]
+ optims = self.optimizers[k]
+ for kp, optimizer in zip(keys, optims):
+ for i, param_group in enumerate(optimizer.param_groups):
+ loss_log["Optimizer/{}{}_lr".format(kp, i)] = param_group["lr"]
+
+ # extract relevant logs for critic, and actor
+ loss_log["Loss"] = 0.
+ for loss_logger in [self._log_critic_info, self._log_actor_info]:
+ this_log = loss_logger(info)
+ if "Loss" in this_log:
+ # manually merge total loss
+ loss_log["Loss"] += this_log["Loss"]
+ del this_log["Loss"]
+ loss_log.update(this_log)
+
+ return loss_log
+
+ def _log_critic_info(self, info):
+ """
+ Helper function to extract critic-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ if "done_masks" in info:
+ loss_log["Critic/Done_Mask_Percentage"] = 100. * torch.mean(info["done_masks"]).item()
+ if "critic/q_targets" in info:
+ loss_log["Critic/Q_Targets"] = info["critic/q_targets"].mean().item()
+ loss_log["Loss"] = 0.
+ for critic_ind in range(len(self.nets["critic"])):
+ loss_log["Critic/Critic{}_Loss".format(critic_ind + 1)] = info["critic/critic{}_loss".format(critic_ind + 1)].item()
+ if "critic/critic{}_grad_norms".format(critic_ind + 1) in info:
+ loss_log["Critic/Critic{}_Grad_Norms".format(critic_ind + 1)] = info["critic/critic{}_grad_norms".format(critic_ind + 1)]
+ loss_log["Loss"] += loss_log["Critic/Critic{}_Loss".format(critic_ind + 1)]
+ if "critic/cql_weight_loss" in info:
+ loss_log["Critic/CQL_Weight"] = info["critic/cql_weight"]
+ loss_log["Critic/CQL_Weight_Loss"] = info["critic/cql_weight_loss"]
+ loss_log["Critic/CQL_Grad_Norms"] = info["critic/cql_grad_norms"]
+ return loss_log
+
+ def _log_actor_info(self, info):
+ """
+ Helper function to extract actor-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ loss_log["Actor/Loss"] = info["actor/loss"].item()
+ if "actor/grad_norms" in info:
+ loss_log["Actor/Grad_Norms"] = info["actor/grad_norms"]
+ loss_log["Loss"] = loss_log["Actor/Loss"]
+ loss_log["Entropy_Weight_Loss"] = info["entropy_weight_loss"]
+ loss_log["Entropy_Weight"] = info["entropy_weight"]
+ if "entropy_grad_norms" in info:
+ loss_log["Entropy_Grad_Norms"] = info["entropy_grad_norms"]
+ return loss_log
+
+ def set_train(self):
+ """
+ Prepare networks for evaluation. Update from super class to make sure
+ target networks stay in evaluation mode all the time.
+ """
+ self.nets.train()
+
+ # target networks always in eval
+ for critic in self.nets["critic_target"]:
+ critic.eval()
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+
+ # LR scheduling updates
+ for lr_sc in self.lr_schedulers["critic"]:
+ if lr_sc is not None:
+ lr_sc.step()
+
+ if self.lr_schedulers["actor"] is not None:
+ self.lr_schedulers["actor"].step()
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["actor"](obs_dict=obs_dict, goal_dict=goal_dict)
+
+ def get_state_action_value(self, obs_dict, actions, goal_dict=None):
+ """
+ Get state-action value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ actions (torch.Tensor): action
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["critic"][0](obs_dict, actions, goal_dict)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/diffusion_policy.py b/phantom/submodules/phantom-robomimic/robomimic/algo/diffusion_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5262ae8b2aac4cc4f8e947fd4d6b8b513d8a83fb
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/diffusion_policy.py
@@ -0,0 +1,693 @@
+"""
+Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
+"""
+from typing import Callable, Union
+import math
+from collections import OrderedDict, deque
+from packaging.version import parse as parse_version
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# requires diffusers==0.11.1
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+from diffusers.training_utils import EMAModel
+
+import robomimic.models.obs_nets as ObsNets
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+
+from robomimic.algo import register_algo_factory_func, PolicyAlgo
+
+@register_algo_factory_func("diffusion_policy")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the BC algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+
+ if algo_config.unet.enabled:
+ return DiffusionPolicyUNet, {}
+ elif algo_config.transformer.enabled:
+ raise NotImplementedError()
+ else:
+ raise RuntimeError()
+
+class DiffusionPolicyUNet(PolicyAlgo):
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ # set up different observation groups for @MIMO_MLP
+ observation_group_shapes = OrderedDict()
+ observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
+ encoder_kwargs = ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder)
+
+ obs_encoder = ObsNets.ObservationGroupEncoder(
+ observation_group_shapes=observation_group_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+ # IMPORTANT!
+ # replace all BatchNorm with GroupNorm to work with EMA
+ # performance will tank if you forget to do this!
+ obs_encoder = replace_bn_with_gn(obs_encoder)
+
+ obs_dim = obs_encoder.output_shape()[0]
+
+ # create network object
+ noise_pred_net = ConditionalUnet1D(
+ input_dim=self.ac_dim,
+ global_cond_dim=obs_dim*self.algo_config.horizon.observation_horizon
+ )
+
+ # the final arch has 2 parts
+ nets = nn.ModuleDict({
+ 'policy': nn.ModuleDict({
+ 'obs_encoder': obs_encoder,
+ 'noise_pred_net': noise_pred_net
+ })
+ })
+
+ nets = nets.float().to(self.device)
+
+ # setup noise scheduler
+ noise_scheduler = None
+ if self.algo_config.ddpm.enabled:
+ noise_scheduler = DDPMScheduler(
+ num_train_timesteps=self.algo_config.ddpm.num_train_timesteps,
+ beta_schedule=self.algo_config.ddpm.beta_schedule,
+ clip_sample=self.algo_config.ddpm.clip_sample,
+ prediction_type=self.algo_config.ddpm.prediction_type
+ )
+ elif self.algo_config.ddim.enabled:
+ noise_scheduler = DDIMScheduler(
+ num_train_timesteps=self.algo_config.ddim.num_train_timesteps,
+ beta_schedule=self.algo_config.ddim.beta_schedule,
+ clip_sample=self.algo_config.ddim.clip_sample,
+ set_alpha_to_one=self.algo_config.ddim.set_alpha_to_one,
+ steps_offset=self.algo_config.ddim.steps_offset,
+ prediction_type=self.algo_config.ddim.prediction_type
+ )
+ else:
+ raise RuntimeError()
+
+ # setup EMA
+ ema = None
+ if self.algo_config.ema.enabled:
+ ema = EMAModel(parameters=nets.parameters(), power=self.algo_config.ema.power)
+
+ # set attrs
+ self.nets = nets
+ self._shadow_nets = copy.deepcopy(self.nets).eval()
+ self._shadow_nets.requires_grad_(False)
+ self.noise_scheduler = noise_scheduler
+ self.ema = ema
+ self.action_check_done = False
+ self.obs_queue = None
+ self.action_queue = None
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ To = self.algo_config.horizon.observation_horizon
+ Ta = self.algo_config.horizon.action_horizon
+ Tp = self.algo_config.horizon.prediction_horizon
+
+ input_batch = dict()
+ input_batch["obs"] = {k: batch["obs"][k][:, :To, :] for k in batch["obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"][:, :Tp, :]
+
+ # check if actions are normalized to [-1,1]
+ if not self.action_check_done:
+ actions = input_batch["actions"]
+ in_range = (-1 <= actions) & (actions <= 1)
+ all_in_range = torch.all(in_range).item()
+ if not all_in_range:
+ raise ValueError('"actions" must be in range [-1,1] for Diffusion Policy! Check if hdf5_normalize_action is enabled.')
+ self.action_check_done = True
+
+ return TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device)
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ To = self.algo_config.horizon.observation_horizon
+ Ta = self.algo_config.horizon.action_horizon
+ Tp = self.algo_config.horizon.prediction_horizon
+ action_dim = self.ac_dim
+ B = batch['actions'].shape[0]
+
+
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ info = super(DiffusionPolicyUNet, self).train_on_batch(batch, epoch, validate=validate)
+ actions = batch['actions']
+
+ # encode obs
+ inputs = {
+ 'obs': batch["obs"],
+ 'goal': batch["goal_obs"]
+ }
+ for k in self.obs_shapes:
+ # first two dimensions should be [B, T] for inputs
+ assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])
+
+ obs_features = TensorUtils.time_distributed(inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True)
+ assert obs_features.ndim == 3 # [B, T, D]
+
+ obs_cond = obs_features.flatten(start_dim=1)
+
+ # sample noise to add to actions
+ noise = torch.randn(actions.shape, device=self.device)
+
+ # sample a diffusion iteration for each data point
+ timesteps = torch.randint(
+ 0, self.noise_scheduler.config.num_train_timesteps,
+ (B,), device=self.device
+ ).long()
+
+ # add noise to the clean actions according to the noise magnitude at each diffusion iteration
+ # (this is the forward diffusion process)
+ noisy_actions = self.noise_scheduler.add_noise(
+ actions, noise, timesteps)
+
+ # predict the noise residual
+ noise_pred = self.nets['policy']['noise_pred_net'](
+ noisy_actions, timesteps, global_cond=obs_cond)
+
+ # L2 loss
+ loss = F.mse_loss(noise_pred, noise)
+
+ # logging
+ losses = {
+ 'l2_loss': loss
+ }
+ info["losses"] = TensorUtils.detach(losses)
+
+ if not validate:
+ # gradient step
+ policy_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets,
+ optim=self.optimizers["policy"],
+ loss=loss,
+ )
+
+ # update Exponential Moving Average of the model weights
+ if self.ema is not None:
+ self.ema.step(self.nets.parameters())
+
+ step_info = {
+ 'policy_grad_norms': policy_grad_norms
+ }
+ info.update(step_info)
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = super(DiffusionPolicyUNet, self).log_info(info)
+ log["Loss"] = info["losses"]["l2_loss"].item()
+ if "policy_grad_norms" in info:
+ log["Policy_Grad_Norms"] = info["policy_grad_norms"]
+ return log
+
+ def reset(self):
+ """
+ Reset algo state to prepare for environment rollouts.
+ """
+ # setup inference queues
+ To = self.algo_config.horizon.observation_horizon
+ Ta = self.algo_config.horizon.action_horizon
+ obs_queue = deque(maxlen=To)
+ action_queue = deque(maxlen=Ta)
+ self.obs_queue = obs_queue
+ self.action_queue = action_queue
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation [1, Do]
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor [1, Da]
+ """
+ # obs_dict: key: [1,D]
+ To = self.algo_config.horizon.observation_horizon
+ Ta = self.algo_config.horizon.action_horizon
+
+ # TODO: obs_queue already handled by frame_stack
+ # make sure we have at least To observations in obs_queue
+ # if not enough, repeat
+ # if already full, append one to the obs_queue
+ # n_repeats = max(To - len(self.obs_queue), 1)
+ # self.obs_queue.extend([obs_dict] * n_repeats)
+
+ if len(self.action_queue) == 0:
+ # no actions left, run inference
+ # turn obs_queue into dict of tensors (concat at T dim)
+ # import pdb; pdb.set_trace()
+ # obs_dict_list = TensorUtils.list_of_flat_dict_to_dict_of_list(list(self.obs_queue))
+ # obs_dict_tensor = dict((k, torch.cat(v, dim=0).unsqueeze(0)) for k,v in obs_dict_list.items())
+
+ # run inference
+ # [1,T,Da]
+ action_sequence = self._get_action_trajectory(obs_dict=obs_dict)
+
+ # put actions into the queue
+ self.action_queue.extend(action_sequence[0])
+
+ # has action, execute from left to right
+ # [Da]
+ action = self.action_queue.popleft()
+
+ # [1,Da]
+ action = action.unsqueeze(0)
+ return action
+
+ def _get_action_trajectory(self, obs_dict, goal_dict=None):
+ assert not self.nets.training
+ To = self.algo_config.horizon.observation_horizon
+ Ta = self.algo_config.horizon.action_horizon
+ Tp = self.algo_config.horizon.prediction_horizon
+ action_dim = self.ac_dim
+ if self.algo_config.ddpm.enabled is True:
+ num_inference_timesteps = self.algo_config.ddpm.num_inference_timesteps
+ elif self.algo_config.ddim.enabled is True:
+ num_inference_timesteps = self.algo_config.ddim.num_inference_timesteps
+ else:
+ raise ValueError
+
+ # select network
+ nets = self.nets
+ if self.ema is not None:
+ self.ema.copy_to(parameters=self._shadow_nets.parameters())
+ nets = self._shadow_nets
+
+ # encode obs
+ inputs = {
+ 'obs': obs_dict,
+ 'goal': goal_dict
+ }
+ for k in self.obs_shapes:
+ # first two dimensions should be [B, T] for inputs
+ assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])
+ obs_features = TensorUtils.time_distributed(inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True)
+ assert obs_features.ndim == 3 # [B, T, D]
+ B = obs_features.shape[0]
+
+ # reshape observation to (B,obs_horizon*obs_dim)
+ obs_cond = obs_features.flatten(start_dim=1)
+
+ # initialize action from Guassian noise
+ noisy_action = torch.randn(
+ (B, Tp, action_dim), device=self.device)
+ naction = noisy_action
+
+ # init scheduler
+ self.noise_scheduler.set_timesteps(num_inference_timesteps)
+
+ for k in self.noise_scheduler.timesteps:
+ # predict noise
+ noise_pred = nets['policy']['noise_pred_net'](
+ sample=naction,
+ timestep=k,
+ global_cond=obs_cond
+ )
+
+ # inverse diffusion step (remove noise)
+ naction = self.noise_scheduler.step(
+ model_output=noise_pred,
+ timestep=k,
+ sample=naction
+ ).prev_sample
+
+ # process action using Ta
+ start = To - 1
+ end = start + Ta
+ action = naction[:,start:end]
+ return action
+
+ def serialize(self):
+ """
+ Get dictionary of current model parameters.
+ """
+ return {
+ "nets": self.nets.state_dict(),
+ "ema": self.ema.state_dict() if self.ema is not None else None,
+ }
+
+ def deserialize(self, model_dict):
+ """
+ Load model from a checkpoint.
+
+ Args:
+ model_dict (dict): a dictionary saved by self.serialize() that contains
+ the same keys as @self.network_classes
+ """
+ self.nets.load_state_dict(model_dict["nets"])
+ if model_dict.get("ema", None) is not None:
+ self.ema.load_state_dict(model_dict["ema"])
+
+
+# =================== Vision Encoder Utils =====================
+def replace_submodules(
+ root_module: nn.Module,
+ predicate: Callable[[nn.Module], bool],
+ func: Callable[[nn.Module], nn.Module]) -> nn.Module:
+ """
+ Replace all submodules selected by the predicate with
+ the output of func.
+
+ predicate: Return true if the module is to be replaced.
+ func: Return new module to use.
+ """
+ if predicate(root_module):
+ return func(root_module)
+
+ if parse_version(torch.__version__) < parse_version('1.9.0'):
+ raise ImportError('This function requires pytorch >= 1.9.0')
+
+ bn_list = [k.split('.') for k, m
+ in root_module.named_modules(remove_duplicate=True)
+ if predicate(m)]
+ for *parent, k in bn_list:
+ parent_module = root_module
+ if len(parent) > 0:
+ parent_module = root_module.get_submodule('.'.join(parent))
+ if isinstance(parent_module, nn.Sequential):
+ src_module = parent_module[int(k)]
+ else:
+ src_module = getattr(parent_module, k)
+ tgt_module = func(src_module)
+ if isinstance(parent_module, nn.Sequential):
+ parent_module[int(k)] = tgt_module
+ else:
+ setattr(parent_module, k, tgt_module)
+ # verify that all modules are replaced
+ bn_list = [k.split('.') for k, m
+ in root_module.named_modules(remove_duplicate=True)
+ if predicate(m)]
+ assert len(bn_list) == 0
+ return root_module
+
+def replace_bn_with_gn(
+ root_module: nn.Module,
+ features_per_group: int=16) -> nn.Module:
+ """
+ Relace all BatchNorm layers with GroupNorm.
+ """
+ replace_submodules(
+ root_module=root_module,
+ predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+ func=lambda x: nn.GroupNorm(
+ num_groups=x.num_features//features_per_group,
+ num_channels=x.num_features)
+ )
+ return root_module
+
+# =================== UNet for Diffusion ==============
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class Downsample1d(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+class Upsample1d(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Conv1dBlock(nn.Module):
+ '''
+ Conv1d --> GroupNorm --> Mish
+ '''
+
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ super().__init__()
+
+ self.block = nn.Sequential(
+ nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+ nn.GroupNorm(n_groups, out_channels),
+ nn.Mish(),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class ConditionalResidualBlock1D(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ cond_dim,
+ kernel_size=3,
+ n_groups=8):
+ super().__init__()
+
+ self.blocks = nn.ModuleList([
+ Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
+ Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
+ ])
+
+ # FiLM modulation https://arxiv.org/abs/1709.07871
+ # predicts per-channel scale and bias
+ cond_channels = out_channels * 2
+ self.out_channels = out_channels
+ self.cond_encoder = nn.Sequential(
+ nn.Mish(),
+ nn.Linear(cond_dim, cond_channels),
+ nn.Unflatten(-1, (-1, 1))
+ )
+
+ # make sure dimensions compatible
+ self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
+ if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x, cond):
+ '''
+ x : [ batch_size x in_channels x horizon ]
+ cond : [ batch_size x cond_dim]
+
+ returns:
+ out : [ batch_size x out_channels x horizon ]
+ '''
+ out = self.blocks[0](x)
+ embed = self.cond_encoder(cond)
+
+ embed = embed.reshape(
+ embed.shape[0], 2, self.out_channels, 1)
+ scale = embed[:,0,...]
+ bias = embed[:,1,...]
+ out = scale * out + bias
+
+ out = self.blocks[1](out)
+ out = out + self.residual_conv(x)
+ return out
+
+
+class ConditionalUnet1D(nn.Module):
+ def __init__(self,
+ input_dim,
+ global_cond_dim,
+ diffusion_step_embed_dim=256,
+ down_dims=[256,512,1024],
+ kernel_size=5,
+ n_groups=8
+ ):
+ """
+ input_dim: Dim of actions.
+ global_cond_dim: Dim of global conditioning applied with FiLM
+ in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
+ diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
+ down_dims: Channel size for each UNet level.
+ The length of this array determines numebr of levels.
+ kernel_size: Conv kernel size
+ n_groups: Number of groups for GroupNorm
+ """
+
+ super().__init__()
+ all_dims = [input_dim] + list(down_dims)
+ start_dim = down_dims[0]
+
+ dsed = diffusion_step_embed_dim
+ diffusion_step_encoder = nn.Sequential(
+ SinusoidalPosEmb(dsed),
+ nn.Linear(dsed, dsed * 4),
+ nn.Mish(),
+ nn.Linear(dsed * 4, dsed),
+ )
+ cond_dim = dsed + global_cond_dim
+
+ in_out = list(zip(all_dims[:-1], all_dims[1:]))
+ mid_dim = all_dims[-1]
+ self.mid_modules = nn.ModuleList([
+ ConditionalResidualBlock1D(
+ mid_dim, mid_dim, cond_dim=cond_dim,
+ kernel_size=kernel_size, n_groups=n_groups
+ ),
+ ConditionalResidualBlock1D(
+ mid_dim, mid_dim, cond_dim=cond_dim,
+ kernel_size=kernel_size, n_groups=n_groups
+ ),
+ ])
+
+ down_modules = nn.ModuleList([])
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ is_last = ind >= (len(in_out) - 1)
+ down_modules.append(nn.ModuleList([
+ ConditionalResidualBlock1D(
+ dim_in, dim_out, cond_dim=cond_dim,
+ kernel_size=kernel_size, n_groups=n_groups),
+ ConditionalResidualBlock1D(
+ dim_out, dim_out, cond_dim=cond_dim,
+ kernel_size=kernel_size, n_groups=n_groups),
+ Downsample1d(dim_out) if not is_last else nn.Identity()
+ ]))
+
+ up_modules = nn.ModuleList([])
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+ is_last = ind >= (len(in_out) - 1)
+ up_modules.append(nn.ModuleList([
+ ConditionalResidualBlock1D(
+ dim_out*2, dim_in, cond_dim=cond_dim,
+ kernel_size=kernel_size, n_groups=n_groups),
+ ConditionalResidualBlock1D(
+ dim_in, dim_in, cond_dim=cond_dim,
+ kernel_size=kernel_size, n_groups=n_groups),
+ Upsample1d(dim_in) if not is_last else nn.Identity()
+ ]))
+
+ final_conv = nn.Sequential(
+ Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
+ nn.Conv1d(start_dim, input_dim, 1),
+ )
+
+ self.diffusion_step_encoder = diffusion_step_encoder
+ self.up_modules = up_modules
+ self.down_modules = down_modules
+ self.final_conv = final_conv
+
+ print("number of parameters: {:e}".format(
+ sum(p.numel() for p in self.parameters()))
+ )
+
+ def forward(self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ global_cond=None):
+ """
+ x: (B,T,input_dim)
+ timestep: (B,) or int, diffusion step
+ global_cond: (B,global_cond_dim)
+ output: (B,T,input_dim)
+ """
+ # (B,T,C)
+ sample = sample.moveaxis(-1,-2)
+ # (B,C,T)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ global_feature = self.diffusion_step_encoder(timesteps)
+
+ if global_cond is not None:
+ global_feature = torch.cat([
+ global_feature, global_cond
+ ], axis=-1)
+
+ x = sample
+ h = []
+ for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
+ x = resnet(x, global_feature)
+ x = resnet2(x, global_feature)
+ h.append(x)
+ x = downsample(x)
+
+ for mid_module in self.mid_modules:
+ x = mid_module(x, global_feature)
+
+ for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
+ x = torch.cat((x, h.pop()), dim=1)
+ x = resnet(x, global_feature)
+ x = resnet2(x, global_feature)
+ x = upsample(x)
+
+ x = self.final_conv(x)
+
+ # (B,C,T)
+ x = x.moveaxis(-1,-2)
+ # (B,T,C)
+ return x
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/gl.py b/phantom/submodules/phantom-robomimic/robomimic/algo/gl.py
new file mode 100644
index 0000000000000000000000000000000000000000..24ae800892ee0866f9b4df3d94ff49eb1cd8d112
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/gl.py
@@ -0,0 +1,775 @@
+"""
+Subgoal prediction models, used in HBC / IRIS.
+"""
+import numpy as np
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+import robomimic.models.obs_nets as ObsNets
+import robomimic.models.vae_nets as VAENets
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+
+from robomimic.algo import register_algo_factory_func, PlannerAlgo, ValueAlgo
+
+
+@register_algo_factory_func("gl")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the GL algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ if algo_config.vae.enabled:
+ return GL_VAE, {}
+ return GL, {}
+
+
+class GL(PlannerAlgo):
+ """
+ Implements goal prediction component for HBC and IRIS.
+ """
+ def __init__(
+ self,
+ algo_config,
+ obs_config,
+ global_config,
+ obs_key_shapes,
+ ac_dim,
+ device
+ ):
+ """
+ Args:
+ algo_config (Config object): instance of Config corresponding to the algo section
+ of the config
+
+ obs_config (Config object): instance of Config corresponding to the observation
+ section of the config
+
+ global_config (Config object): global training config
+
+ obs_key_shapes (OrderedDict): dictionary that maps observation keys to shapes
+
+ ac_dim (int): dimension of action space
+
+ device (torch.Device): where the algo should live (i.e. cpu, gpu)
+ """
+
+ self._subgoal_horizon = algo_config.subgoal_horizon
+ super(GL, self).__init__(
+ algo_config=algo_config,
+ obs_config=obs_config,
+ global_config=global_config,
+ obs_key_shapes=obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device
+ )
+
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+
+ obs_group_shapes = OrderedDict()
+ obs_group_shapes["obs"] = OrderedDict(self.obs_shapes)
+ if len(self.goal_shapes) > 0:
+ obs_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+
+ # deterministic goal prediction network
+ self.nets["goal_network"] = ObsNets.MIMO_MLP(
+ input_obs_group_shapes=obs_group_shapes,
+ output_shapes=self.subgoal_shapes,
+ layer_dims=self.algo_config.ae.planner_layer_dims,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ # remove temporal batches for all except scalar signals (to be compatible with model outputs)
+ input_batch["obs"] = { k: batch["obs"][k][:, 0, :] for k in batch["obs"] }
+ # extract multi-horizon subgoal target
+ input_batch["subgoals"] = {k: batch["next_obs"][k][:, self._subgoal_horizon - 1, :] for k in batch["next_obs"]}
+ input_batch["target_subgoals"] = input_batch["subgoals"]
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def get_actor_goal_for_training_from_processed_batch(self, processed_batch, **kwargs):
+ """
+ Retrieve subgoals from processed batch to use for training the actor. Subclasses
+ can modify this function to change the subgoals.
+
+ Args:
+ processed_batch (dict): processed batch from @process_batch_for_training
+
+ Returns:
+ actor_subgoals (dict): subgoal observations to condition actor on
+ """
+ return processed_batch["target_subgoals"]
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ info = super(GL, self).train_on_batch(batch, epoch, validate=validate)
+
+ # predict subgoal observations with goal network
+ pred_subgoals = self.nets["goal_network"](obs=batch["obs"], goal=batch["goal_obs"])
+
+ # compute loss as L2 error for each observation key
+ losses = OrderedDict()
+ target_subgoals = batch["target_subgoals"] # targets for network prediction
+ goal_loss = 0.
+ for k in pred_subgoals:
+ assert pred_subgoals[k].shape == target_subgoals[k].shape, "mismatch in predicted and target subgoals!"
+ mode_loss = nn.MSELoss()(pred_subgoals[k], target_subgoals[k])
+ goal_loss += mode_loss
+ losses["goal_{}_loss".format(k)] = mode_loss
+ losses["goal_loss"] = goal_loss
+ info.update(TensorUtils.detach(losses))
+
+ if not validate:
+ # gradient step
+ goal_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["goal_network"],
+ optim=self.optimizers["goal_network"],
+ loss=losses["goal_loss"],
+ )
+ info["goal_grad_norms"] = goal_grad_norms
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ loss_log = super(GL, self).log_info(info)
+
+ loss_log["Loss"] = info["goal_loss"].item()
+ for k in info:
+ if k.endswith("_loss"):
+ loss_log[k] = info[k].item()
+ if "goal_grad_norms" in info:
+ loss_log["Grad_Norms"] = info["goal_grad_norms"]
+
+ return loss_log
+
+ def get_subgoal_predictions(self, obs_dict, goal_dict=None):
+ """
+ Takes a batch of observations and predicts a batch of subgoals.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoal prediction (dict): name -> Tensor [batch_size, ...]
+ """
+ return self.nets["goal_network"](obs=obs_dict, goal=goal_dict)
+
+ def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
+ """
+ Sample @num_samples subgoals from the network per observation.
+ Since this class implements a deterministic subgoal prediction,
+ this function returns identical subgoals for each input observation.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
+ """
+
+ # stack observations to get all samples in one forward pass
+ obs_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_samples)
+ goal_tiled = None
+ if goal_dict is not None:
+ goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)
+
+ # [batch_size * num_samples, ...]
+ goals = self.get_subgoal_predictions(obs_dict=obs_tiled, goal_dict=goal_tiled)
+ # reshape to [batch_size, num_samples, ...]
+ return TensorUtils.reshape_dimensions(goals, begin_axis=0, end_axis=0, target_dims=(-1, num_samples))
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs. Assumes one input observation (first dimension should be 1).
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ raise Exception("Rollouts are not supported by GL")
+
+
+class GL_VAE(GL):
+ """
+ Implements goal prediction via VAE.
+ """
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+
+ self.nets["goal_network"] = VAENets.VAE(
+ input_shapes=self.subgoal_shapes,
+ output_shapes=self.subgoal_shapes,
+ condition_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ device=self.device,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **VAENets.vae_args_from_config(self.algo_config.vae),
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+ def get_actor_goal_for_training_from_processed_batch(
+ self,
+ processed_batch,
+ use_latent_subgoals=False,
+ use_prior_correction=False,
+ num_prior_samples=100,
+ **kwargs,
+ ):
+ """
+ Modify from superclass to support a @use_latent_subgoals option.
+ The VAE can optionally return latent subgoals by passing the subgoal
+ observations in the batch through the encoder.
+
+ Args:
+ processed_batch (dict): processed batch from @process_batch_for_training
+
+ use_latent_subgoals (bool): if True, condition the actor on latent subgoals
+ by using the VAE encoder to encode subgoal observations at train-time,
+ and using the VAE prior to generate latent subgoals at test-time
+
+ use_prior_correction (bool): if True, use a "prior correction" trick to
+ choose a latent subgoal sampled from the prior that is close to the
+ latent from the VAE encoder (posterior). This can help with issues at
+ test-time where the encoder latent distribution might not match
+ the prior latent distribution.
+
+ num_prior_samples (int): number of VAE prior samples to take and choose among,
+ if @use_prior_correction is true
+
+ Returns:
+ actor_subgoals (dict): subgoal observations to condition actor on
+ """
+
+ if not use_latent_subgoals:
+ return processed_batch["target_subgoals"]
+
+ # batch variables
+ obs = processed_batch["obs"]
+ subgoals = processed_batch["subgoals"] # full subgoal observations
+ target_subgoals = processed_batch["target_subgoals"] # targets for network prediction
+ goal_obs = processed_batch["goal_obs"]
+
+ with torch.no_grad():
+ # run VAE forward pass to get samples from posterior for the current observation and subgoal
+ vae_outputs = self.nets["goal_network"](
+ inputs=subgoals, # encoder takes full subgoals
+ outputs=target_subgoals, # reconstruct target subgoals
+ goals=goal_obs,
+ conditions=obs, # condition on observations
+ )
+ posterior_z = vae_outputs["encoder_z"]
+ latent_subgoals = posterior_z
+
+ if use_prior_correction:
+ # instead of treating posterior samples as latent subgoals, sample latents from
+ # the prior and choose the closest one as the latent subgoal
+
+ random_key = list(obs.keys())[0]
+ batch_size = obs[random_key].shape[0]
+
+ # for each batch member, get @num_prior_samples samples from the prior
+ obs_tiled = ObsUtils.repeat_and_stack_observation(obs, n=num_prior_samples)
+ goal_tiled = None
+ if len(self.goal_shapes) > 0:
+ goal_tiled = ObsUtils.repeat_and_stack_observation(goal_obs, n=num_prior_samples)
+
+ prior_z_samples = self.nets["goal_network"].sample_prior(
+ conditions=obs_tiled,
+ goals=goal_tiled,
+ )
+
+ # choose prior samples that are closest to the sampled posterior latents
+ # note: every posterior sample in the batch has @num_prior_samples corresponding prior samples
+
+ # reshape prior samples to (batch_size, num_samples, latent_dim)
+ prior_z_samples = prior_z_samples.reshape(batch_size, num_prior_samples, -1)
+
+ # reshape posterior latents to (batch_size, 1, latent_dim)
+ posterior_z_expanded = posterior_z.unsqueeze(1)
+
+ # compute distances with broadcasting so that each posterior sample
+ # has distances to all of its prior samples
+ distances = (prior_z_samples - posterior_z_expanded).pow(2).sum(dim=2)
+
+ # then gather the closest prior sample for each posterior sample
+ neighbors = torch.argmin(distances, dim=1)
+ latent_subgoals = prior_z_samples[torch.arange(batch_size).long(), neighbors]
+
+ return { "latent_subgoal" : latent_subgoals }
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ info = super(GL, self).train_on_batch(batch, epoch, validate=validate)
+
+ if self.algo_config.vae.prior.use_categorical:
+ temperature = self.algo_config.vae.prior.categorical_init_temp - epoch * self.algo_config.vae.prior.categorical_temp_anneal_step
+ temperature = max(temperature, self.algo_config.vae.prior.categorical_min_temp)
+ self.nets["goal_network"].set_gumbel_temperature(temperature)
+
+ # batch variables
+ obs = batch["obs"]
+ subgoals = batch["subgoals"] # full subgoal observations
+ target_subgoals = batch["target_subgoals"] # targets for network prediction
+ goal_obs = batch["goal_obs"]
+
+ vae_outputs = self.nets["goal_network"](
+ inputs=subgoals, # encoder takes full subgoals
+ outputs=target_subgoals, # reconstruct target subgoals
+ goals=goal_obs,
+ conditions=obs, # condition on observations
+ )
+ recons_loss = vae_outputs["reconstruction_loss"]
+ kl_loss = vae_outputs["kl_loss"]
+ goal_loss = recons_loss + self.algo_config.vae.kl_weight * kl_loss
+ info["recons_loss"] = recons_loss
+ info["kl_loss"] = kl_loss
+ info["goal_loss"] = goal_loss
+
+ if not self.algo_config.vae.prior.use_categorical:
+ with torch.no_grad():
+ info["encoder_variance"] = torch.exp(vae_outputs["encoder_params"]["logvar"])
+
+ # VAE gradient step
+ if not validate:
+ goal_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["goal_network"],
+ optim=self.optimizers["goal_network"],
+ loss=goal_loss,
+ )
+ info["goal_grad_norms"] = goal_grad_norms
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ loss_log = super(GL_VAE, self).log_info(info)
+ loss_log["Reconstruction_Loss"] = info["recons_loss"].item()
+ loss_log["KL_Loss"] = info["kl_loss"].item()
+ if self.algo_config.vae.prior.use_categorical:
+ loss_log["Gumbel_Temperature"] = self.nets["goal_network"].get_gumbel_temperature()
+ else:
+ loss_log["Encoder_Variance"] = info["encoder_variance"].mean().item()
+ return loss_log
+
+ def get_subgoal_predictions(self, obs_dict, goal_dict=None):
+ """
+ Takes a batch of observations and predicts a batch of subgoals.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoal prediction (dict): name -> Tensor [batch_size, ...]
+ """
+
+ if self.global_config.algo.latent_subgoal.enabled:
+ # latent subgoals from sampling prior
+ latent_subgoals = self.nets["goal_network"].sample_prior(
+ conditions=obs_dict,
+ goals=goal_dict,
+ )
+
+ return OrderedDict(latent_subgoal=latent_subgoals)
+
+ # sample a single goal from the VAE
+ goals = self.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=1)
+ return { k : goals[k][:, 0, ...] for k in goals }
+
+ def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
+ """
+ Sample @num_samples subgoals from the VAE per observation.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
+ """
+
+ # stack observations to get all samples in one forward pass
+ obs_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_samples)
+ goal_tiled = None
+ if goal_dict is not None:
+ goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)
+
+ # VAE decode expects number of samples explicitly
+ mod = list(obs_tiled.keys())[0]
+ n = obs_tiled[mod].shape[0]
+ # [batch_size * num_samples, ...]
+ goals = self.nets["goal_network"].decode(n=n, conditions=obs_tiled, goals=goal_tiled)
+ # reshape to [batch_size, num_samples, ...]
+ return TensorUtils.reshape_dimensions(goals, begin_axis=0, end_axis=0, target_dims=(-1, num_samples))
+
+
+class ValuePlanner(PlannerAlgo, ValueAlgo):
+ """
+ Base class for all algorithms that are used for planning subgoals
+ based on (1) a @PlannerAlgo that is used to sample candidate subgoals
+ and (2) a @ValueAlgo that is used to select one of the subgoals.
+ """
+ def __init__(
+ self,
+ planner_algo_class,
+ value_algo_class,
+ algo_config,
+ obs_config,
+ global_config,
+ obs_key_shapes,
+ ac_dim,
+ device,
+
+ ):
+ """
+ Args:
+ planner_algo_class (Algo class): algo class for the planner
+
+ value_algo_class (Algo class): algo class for the value network
+
+ algo_config (Config object): instance of Config corresponding to the algo section
+ of the config
+
+ obs_config (Config object): instance of Config corresponding to the observation
+ section of the config
+
+ global_config (Config object); global config
+
+ obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes
+
+ ac_dim (int): action dimension
+
+ device: torch device
+ """
+ self.algo_config = algo_config
+ self.obs_config = obs_config
+ self.global_config = global_config
+
+ self.ac_dim = ac_dim
+ self.device = device
+
+ self.planner = planner_algo_class(
+ algo_config=algo_config.planner,
+ obs_config=obs_config.planner,
+ global_config=global_config,
+ obs_key_shapes=obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device
+ )
+
+ self.value_net = value_algo_class(
+ algo_config=algo_config.value,
+ obs_config=obs_config.value,
+ global_config=global_config,
+ obs_key_shapes=obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device
+ )
+
+ self.subgoal_shapes = self.planner.subgoal_shapes
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ input_batch["planner"] = self.planner.process_batch_for_training(batch)
+ input_batch["value_net"] = self.value_net.process_batch_for_training(batch)
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ if validate:
+ assert not self.planner.nets.training
+ assert not self.value_net.nets.training
+
+ info = dict(planner=dict(), value_net=dict())
+
+ # train planner
+ info["planner"].update(self.planner.train_on_batch(batch["planner"], epoch, validate=validate))
+
+ # train value network
+ info["value_net"].update(self.value_net.train_on_batch(batch["value_net"], epoch, validate=validate))
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ loss = 0.
+
+ # planner
+ planner_log = self.planner.log_info(info["planner"])
+ planner_log = dict(("Planner/" + k, v) for k, v in planner_log.items())
+ loss += planner_log["Planner/Loss"]
+
+ # value network
+ value_net_log = self.value_net.log_info(info["value_net"])
+ value_net_log = dict(("ValueNetwork/" + k, v) for k, v in value_net_log.items())
+ loss += value_net_log["ValueNetwork/Loss"]
+ planner_log.update(value_net_log)
+
+ planner_log["Loss"] = loss
+ return planner_log
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+ self.planner.on_epoch_end(epoch)
+ self.value_net.on_epoch_end(epoch)
+
+ def set_eval(self):
+ """
+ Prepare networks for evaluation.
+ """
+ self.planner.set_eval()
+ self.value_net.set_eval()
+
+ def set_train(self):
+ """
+ Prepare networks for training.
+ """
+ self.planner.set_train()
+ self.value_net.set_train()
+
+ def serialize(self):
+ """
+ Get dictionary of current model parameters.
+ """
+ return dict(
+ planner=self.planner.serialize(),
+ value_net=self.value_net.serialize(),
+ )
+
+ def deserialize(self, model_dict):
+ """
+ Load model from a checkpoint.
+
+ Args:
+ model_dict (dict): a dictionary saved by self.serialize() that contains
+ the same keys as @self.network_classes
+ """
+ self.planner.deserialize(model_dict["planner"])
+ self.value_net.deserialize(model_dict["value_net"])
+
+ def reset(self):
+ """
+ Reset algo state to prepare for environment rollouts.
+ """
+ self.planner.reset()
+ self.value_net.reset()
+
+ def __repr__(self):
+ """
+ Pretty print algorithm and network description.
+ """
+ msg = str(self.__class__.__name__)
+ import textwrap
+ return msg + "Planner:\n" + textwrap.indent(self.planner.__repr__(), ' ') + \
+ "\n\nValue Network:\n" + textwrap.indent(self.value_net.__repr__(), ' ')
+
+ def get_subgoal_predictions(self, obs_dict, goal_dict=None):
+ """
+ Takes a batch of observations and predicts a batch of subgoals.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoal prediction (dict): name -> Tensor [batch_size, ...]
+ """
+
+ num_samples = self.algo_config.num_samples
+
+ # sample subgoals from the planner (shape: [batch_size, num_samples, ...])
+ subgoals = self.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=num_samples)
+
+ # stack subgoals to get all values in one forward pass (shape [batch_size * num_samples, ...])
+ k = list(obs_dict.keys())[0]
+ bsize = obs_dict[k].shape[0]
+ subgoals_tiled = TensorUtils.reshape_dimensions(subgoals, begin_axis=0, end_axis=1, target_dims=(bsize * num_samples,))
+
+ # also repeat goals if necessary
+ goal_tiled = None
+ if len(self.planner.goal_shapes) > 0:
+ goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)
+
+ # evaluate the value of each subgoal
+ subgoal_values = self.value_net.get_state_value(obs_dict=subgoals_tiled, goal_dict=goal_tiled).reshape(-1, num_samples)
+
+ # pick the best subgoal
+ best_index = torch.argmax(subgoal_values, dim=1)
+ best_subgoal = {k: subgoals[k][torch.arange(bsize), best_index] for k in subgoals}
+ return best_subgoal
+
+ def sample_subgoals(self, obs_dict, goal_dict, num_samples=1):
+ """
+ Sample @num_samples subgoals from the planner algo per observation.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
+ """
+ return self.planner.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=num_samples)
+
+ def get_state_value(self, obs_dict, goal_dict=None):
+ """
+ Get state value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ return self.value_net.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)
+
+ def get_state_action_value(self, obs_dict, actions, goal_dict=None):
+ """
+ Get state-action value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ actions (torch.Tensor): action
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ return self.value_net.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/hbc.py b/phantom/submodules/phantom-robomimic/robomimic/algo/hbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..543b1fbcf4ced11b9628d506b1972f1123a357b6
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/hbc.py
@@ -0,0 +1,344 @@
+"""
+Implementation of Hierarchical Behavioral Cloning, where
+a planner model outputs subgoals (future observations), and
+an actor model is conditioned on the subgoals to try and
+reach them. Largely based on the Generalization Through Imitation (GTI)
+paper (see https://arxiv.org/abs/2003.06085).
+"""
+import textwrap
+import numpy as np
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+from robomimic.config.config import Config
+from robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HierarchicalAlgo, GL_VAE
+
+
+@register_algo_factory_func("hbc")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the HBC algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ pol_cls, _ = algo_name_to_factory_func("bc")(algo_config.actor)
+ plan_cls, _ = algo_name_to_factory_func("gl")(algo_config.planner)
+ return HBC, dict(policy_algo_class=pol_cls, planner_algo_class=plan_cls)
+
+
+class HBC(HierarchicalAlgo):
+ """
+ Default HBC training, largely based on https://arxiv.org/abs/2003.06085
+ """
+ def __init__(
+ self,
+ planner_algo_class,
+ policy_algo_class,
+ algo_config,
+ obs_config,
+ global_config,
+ obs_key_shapes,
+ ac_dim,
+ device,
+ ):
+ """
+ Args:
+ planner_algo_class (Algo class): algo class for the planner
+
+ policy_algo_class (Algo class): algo class for the policy
+
+ algo_config (Config object): instance of Config corresponding to the algo section
+ of the config
+
+ obs_config (Config object): instance of Config corresponding to the observation
+ section of the config
+
+ global_config (Config object): global training config
+
+ obs_key_shapes (dict): dictionary that maps input/output observation keys to shapes
+
+ ac_dim (int): action dimension
+
+ device: torch device
+ """
+ self.algo_config = algo_config
+ self.obs_config = obs_config
+ self.global_config = global_config
+
+ self.ac_dim = ac_dim
+ self.device = device
+
+ self._subgoal_step_count = 0 # current step count for deciding when to update subgoal
+ self._current_subgoal = None # latest subgoal
+ self._subgoal_update_interval = self.algo_config.subgoal_update_interval # subgoal update frequency
+ self._subgoal_horizon = self.algo_config.planner.subgoal_horizon
+ self._actor_horizon = self.algo_config.actor.rnn.horizon
+
+ self._algo_mode = self.algo_config.mode
+ assert self._algo_mode in ["separate", "cascade"]
+
+ self.planner = planner_algo_class(
+ algo_config=algo_config.planner,
+ obs_config=obs_config.planner,
+ global_config=global_config,
+ obs_key_shapes=obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device
+ )
+
+ # goal-conditional actor follows goals set by the planner
+ self.actor_goal_shapes = self.planner.subgoal_shapes
+ if self.algo_config.latent_subgoal.enabled:
+ assert planner_algo_class == GL_VAE # only VAE supported for now
+ self.actor_goal_shapes = OrderedDict(latent_subgoal=(self.planner.algo_config.vae.latent_dim,))
+
+ # only for the actor: override goal modalities and shapes to match the subgoal set by the planner
+ actor_obs_key_shapes = deepcopy(obs_key_shapes)
+ # make sure we are not modifying existing observation key shapes
+ for k in self.actor_goal_shapes:
+ if k in actor_obs_key_shapes:
+ assert actor_obs_key_shapes[k] == self.actor_goal_shapes[k]
+ actor_obs_key_shapes.update(self.actor_goal_shapes)
+
+ goal_obs_keys = {obs_modality: [] for obs_modality in ObsUtils.OBS_MODALITY_CLASSES.keys()}
+ for k in self.actor_goal_shapes.keys():
+ goal_obs_keys[ObsUtils.OBS_KEYS_TO_MODALITIES[k]].append(k)
+
+ actor_obs_config = deepcopy(obs_config.actor)
+ with actor_obs_config.unlocked():
+ actor_obs_config["goal"] = Config(**goal_obs_keys)
+
+ self.actor = policy_algo_class(
+ algo_config=algo_config.actor,
+ obs_config=actor_obs_config,
+ global_config=global_config,
+ obs_key_shapes=actor_obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device,
+ )
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ input_batch["planner"] = self.planner.process_batch_for_training(batch)
+ input_batch["actor"] = self.actor.process_batch_for_training(batch)
+
+ if self.algo_config.actor_use_random_subgoals:
+ # optionally use randomly sampled step between [1, seq_length] as policy goal
+ policy_subgoal_indices = torch.randint(
+ low=0, high=self.global_config.train.seq_length, size=(batch["actions"].shape[0],))
+ goal_obs = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
+ goal_obs = TensorUtils.to_float(TensorUtils.to_device(goal_obs, self.device))
+ input_batch["actor"]["goal_obs"] = \
+ self.planner.get_actor_goal_for_training_from_processed_batch(
+ goal_obs,
+ use_latent_subgoals=self.algo_config.latent_subgoal.enabled,
+ use_prior_correction=self.algo_config.latent_subgoal.prior_correction.enabled,
+ num_prior_samples=self.algo_config.latent_subgoal.prior_correction.num_samples,
+ )
+ else:
+ # otherwise, use planner subgoal target as goal for the policy
+ input_batch["actor"]["goal_obs"] = \
+ self.planner.get_actor_goal_for_training_from_processed_batch(
+ input_batch["planner"],
+ use_latent_subgoals=self.algo_config.latent_subgoal.enabled,
+ use_prior_correction=self.algo_config.latent_subgoal.prior_correction.enabled,
+ num_prior_samples=self.algo_config.latent_subgoal.prior_correction.num_samples,
+ )
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = dict(planner=dict(), actor=dict())
+ # train planner
+ info["planner"].update(self.planner.train_on_batch(batch["planner"], epoch, validate=validate))
+
+ # train actor
+ if self._algo_mode == "separate":
+ # train low-level actor by getting subgoals from the dataset
+ info["actor"].update(self.actor.train_on_batch(batch["actor"], epoch, validate=validate))
+
+ elif self._algo_mode == "cascade":
+ # get predictions from the planner
+ with torch.no_grad():
+ batch["actor"]["goal_obs"] = self.planner.get_subgoal_predictions(
+ obs_dict=batch["planner"]["obs"], goal_dict=batch["planner"]["goal_obs"])
+
+ # train actor with the predicted goal
+ info["actor"].update(self.actor.train_on_batch(batch["actor"], epoch, validate=validate))
+
+ else:
+ raise NotImplementedError("algo mode {} is not implemented".format(self._algo_mode))
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ planner_log = dict()
+ actor_log = dict()
+ loss = 0.
+
+ planner_log = self.planner.log_info(info["planner"])
+ planner_log = dict(("Planner/" + k, v) for k, v in planner_log.items())
+ loss += planner_log["Planner/Loss"]
+
+ actor_log = self.actor.log_info(info["actor"])
+ actor_log = dict(("Actor/" + k, v) for k, v in actor_log.items())
+ loss += actor_log["Actor/Loss"]
+
+ planner_log.update(actor_log)
+ planner_log["Loss"] = loss
+ return planner_log
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+ self.planner.on_epoch_end(epoch)
+ self.actor.on_epoch_end(epoch)
+
+ def set_eval(self):
+ """
+ Prepare networks for evaluation.
+ """
+ self.planner.set_eval()
+ self.actor.set_eval()
+
+ def set_train(self):
+ """
+ Prepare networks for training.
+ """
+ self.planner.set_train()
+ self.actor.set_train()
+
+ def serialize(self):
+ """
+ Get dictionary of current model parameters.
+ """
+ return dict(
+ planner=self.planner.serialize(),
+ actor=self.actor.serialize(),
+ )
+
+ def deserialize(self, model_dict):
+ """
+ Load model from a checkpoint.
+
+ Args:
+ model_dict (dict): a dictionary saved by self.serialize() that contains
+ the same keys as @self.network_classes
+ """
+ self.actor.deserialize(model_dict["actor"])
+ self.planner.deserialize(model_dict["planner"])
+
+ @property
+ def current_subgoal(self):
+ """
+ Return the current subgoal (at rollout time) with shape (batch, ...)
+ """
+ return { k : self._current_subgoal[k].clone() for k in self._current_subgoal }
+
+ @current_subgoal.setter
+ def current_subgoal(self, sg):
+ """
+ Sets the current subgoal being used by the actor.
+ """
+ for k, v in sg.items():
+ if not self.algo_config.latent_subgoal.enabled:
+ # subgoal should only match subgoal shapes if not using latent subgoals
+ assert list(v.shape[1:]) == list(self.planner.subgoal_shapes[k])
+ # subgoal shapes should always match actor goal shapes
+ assert list(v.shape[1:]) == list(self.actor_goal_shapes[k])
+ self._current_subgoal = { k : sg[k].clone() for k in sg }
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ if self._current_subgoal is None or self._subgoal_step_count % self._subgoal_update_interval == 0:
+ # update current subgoal
+ self.current_subgoal = self.planner.get_subgoal_predictions(obs_dict=obs_dict, goal_dict=goal_dict)
+
+ action = self.actor.get_action(obs_dict=obs_dict, goal_dict=self.current_subgoal)
+ self._subgoal_step_count += 1
+ return action
+
+ def reset(self):
+ """
+ Reset algo state to prepare for environment rollouts.
+ """
+ self._current_subgoal = None
+ self._subgoal_step_count = 0
+ self.planner.reset()
+ self.actor.reset()
+
+ def __repr__(self):
+ """
+ Pretty print algorithm and network description.
+ """
+ msg = str(self.__class__.__name__)
+ msg += "(subgoal_horizon={}, actor_horizon={}, subgoal_update_interval={}, mode={}, " \
+ "actor_use_random_subgoals={})\n".format(
+ self._subgoal_horizon,
+ self._actor_horizon,
+ self._subgoal_update_interval,
+ self._algo_mode,
+ self.algo_config.actor_use_random_subgoals
+ )
+ return msg + "Planner:\n" + textwrap.indent(self.planner.__repr__(), ' ') + \
+ "\n\nPolicy:\n" + textwrap.indent(self.actor.__repr__(), ' ')
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/iql.py b/phantom/submodules/phantom-robomimic/robomimic/algo/iql.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde522b2292e6140b5ce4e3120ad0c83e4064fff
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/iql.py
@@ -0,0 +1,428 @@
+"""
+Implementation of Implicit Q-Learning (IQL).
+Based off of https://github.com/rail-berkeley/rlkit/blob/master/rlkit/torch/sac/iql_trainer.py.
+(Paper - https://arxiv.org/abs/2110.06169).
+"""
+import numpy as np
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import robomimic.models.policy_nets as PolicyNets
+import robomimic.models.value_nets as ValueNets
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.algo import register_algo_factory_func, ValueAlgo, PolicyAlgo
+
+
+@register_algo_factory_func("iql")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the IQL algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ return IQL, {}
+
+
+class IQL(PolicyAlgo, ValueAlgo):
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+
+ Networks for this algo: critic (potentially ensemble), actor, value function
+ """
+
+ # Create nets
+ self.nets = nn.ModuleDict()
+
+ # Assemble args to pass to actor
+ actor_args = dict(self.algo_config.actor.net.common)
+
+ # Add network-specific args and define network class
+ if self.algo_config.actor.net.type == "gaussian":
+ actor_cls = PolicyNets.GaussianActorNetwork
+ actor_args.update(dict(self.algo_config.actor.net.gaussian))
+ elif self.algo_config.actor.net.type == "gmm":
+ actor_cls = PolicyNets.GMMActorNetwork
+ actor_args.update(dict(self.algo_config.actor.net.gmm))
+ else:
+ # Unsupported actor type!
+ raise ValueError(f"Unsupported actor requested. "
+ f"Requested: {self.algo_config.actor.net.type}, "
+ f"valid options are: {['gaussian', 'gmm']}")
+
+ # Actor
+ self.nets["actor"] = actor_cls(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor.layer_dims,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ **actor_args,
+ )
+
+ # Critics
+ self.nets["critic"] = nn.ModuleList()
+ self.nets["critic_target"] = nn.ModuleList()
+ for _ in range(self.algo_config.critic.ensemble.n):
+ for net_list in (self.nets["critic"], self.nets["critic_target"]):
+ critic = ValueNets.ActionValueNetwork(
+ obs_shapes=self.obs_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.critic.layer_dims,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+ net_list.append(critic)
+
+ # Value function network
+ self.nets["vf"] = ValueNets.ValueNetwork(
+ obs_shapes=self.obs_shapes,
+ mlp_layer_dims=self.algo_config.critic.layer_dims,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ # Send networks to appropriate device
+ self.nets = self.nets.float().to(self.device)
+
+ # sync target networks at beginning of training
+ with torch.no_grad():
+ for critic, critic_target in zip(self.nets["critic"], self.nets["critic_target"]):
+ TorchUtils.hard_update(
+ source=critic,
+ target=critic_target,
+ )
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out relevant info and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+
+ input_batch = dict()
+
+ # remove temporal batches for all
+ input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
+ input_batch["next_obs"] = {k: batch["next_obs"][k][:, 0, :] for k in batch["next_obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"][:, 0, :]
+ input_batch["dones"] = batch["dones"][:, 0]
+ input_batch["rewards"] = batch["rewards"][:, 0]
+
+ return TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device)
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = OrderedDict()
+
+ # Set the correct context for this training step
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ # Always run super call first
+ info = super().train_on_batch(batch, epoch, validate=validate)
+
+ # Compute loss for critic(s)
+ critic_losses, vf_loss, critic_info = self._compute_critic_loss(batch)
+ # Compute loss for actor
+ actor_loss, actor_info = self._compute_actor_loss(batch, critic_info)
+
+ if not validate:
+ # Critic update
+ self._update_critic(critic_losses, vf_loss)
+
+ # Actor update
+ self._update_actor(actor_loss)
+
+ # Update info
+ info.update(actor_info)
+ info.update(critic_info)
+
+ # Return stats
+ return info
+
+ def _compute_critic_loss(self, batch):
+ """
+ Helper function for computing Q and V losses. Called by @train_on_batch
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ Returns:
+ critic_losses (list): list of critic (Q function) losses
+ vf_loss (torch.Tensor): value function loss
+ info (dict): dictionary of Q / V predictions and losses
+ """
+ info = OrderedDict()
+
+ # get batch values
+ obs = batch["obs"]
+ actions = batch["actions"]
+ next_obs = batch["next_obs"]
+ goal_obs = batch["goal_obs"]
+ rewards = torch.unsqueeze(batch["rewards"], 1)
+ dones = torch.unsqueeze(batch["dones"], 1)
+
+ # Q predictions
+ pred_qs = [critic(obs_dict=obs, acts=actions, goal_dict=goal_obs)
+ for critic in self.nets["critic"]]
+
+ info["critic/critic1_pred"] = pred_qs[0].mean()
+
+ # Q target values
+ target_vf_pred = self.nets["vf"](obs_dict=next_obs, goal_dict=goal_obs).detach()
+ q_target = rewards + (1. - dones) * self.algo_config.discount * target_vf_pred
+ q_target = q_target.detach()
+
+ # Q losses
+ critic_losses = []
+ td_loss_fcn = nn.SmoothL1Loss() if self.algo_config.critic.use_huber else nn.MSELoss()
+ for (i, q_pred) in enumerate(pred_qs):
+ # Calculate td error loss
+ td_loss = td_loss_fcn(q_pred, q_target)
+ info[f"critic/critic{i+1}_loss"] = td_loss
+ critic_losses.append(td_loss)
+
+ # V predictions
+ pred_qs = [critic(obs_dict=obs, acts=actions, goal_dict=goal_obs)
+ for critic in self.nets["critic_target"]]
+ q_pred, _ = torch.cat(pred_qs, dim=1).min(dim=1, keepdim=True)
+ q_pred = q_pred.detach()
+ vf_pred = self.nets["vf"](obs)
+
+ # V losses: expectile regression. see section 4.1 in https://arxiv.org/pdf/2110.06169.pdf
+ vf_err = vf_pred - q_pred
+ vf_sign = (vf_err > 0).float()
+ vf_weight = (1 - vf_sign) * self.algo_config.vf_quantile + vf_sign * (1 - self.algo_config.vf_quantile)
+ vf_loss = (vf_weight * (vf_err ** 2)).mean()
+
+ # update logs for V loss
+ info["vf/q_pred"] = q_pred
+ info["vf/v_pred"] = vf_pred
+ info["vf/v_loss"] = vf_loss
+
+ # Return stats
+ return critic_losses, vf_loss, info
+
+ def _update_critic(self, critic_losses, vf_loss):
+ """
+ Helper function for updating critic and vf networks. Called by @train_on_batch
+
+ Args:
+ critic_losses (list): list of critic (Q function) losses
+ vf_loss (torch.Tensor): value function loss
+ """
+
+ # update ensemble of critics
+ for (critic_loss, critic, critic_target, optimizer) in zip(
+ critic_losses, self.nets["critic"], self.nets["critic_target"], self.optimizers["critic"]
+ ):
+ TorchUtils.backprop_for_loss(
+ net=critic,
+ optim=optimizer,
+ loss=critic_loss,
+ max_grad_norm=self.algo_config.critic.max_gradient_norm,
+ retain_graph=False,
+ )
+
+ # update target network
+ with torch.no_grad():
+ TorchUtils.soft_update(source=critic, target=critic_target, tau=self.algo_config.target_tau)
+
+ # update V function network
+ TorchUtils.backprop_for_loss(
+ net=self.nets["vf"],
+ optim=self.optimizers["vf"],
+ loss=vf_loss,
+ max_grad_norm=self.algo_config.critic.max_gradient_norm,
+ retain_graph=False,
+ )
+
+ def _compute_actor_loss(self, batch, critic_info):
+ """
+ Helper function for computing actor loss. Called by @train_on_batch
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ critic_info (dict): dictionary containing Q and V function predictions,
+ to be used for computing advantage estimates
+
+ Returns:
+ actor_loss (torch.Tensor): actor loss
+ info (dict): dictionary of actor losses, log_probs, advantages, and weights
+ """
+ info = OrderedDict()
+
+ # compute log probability of batch actions
+ dist = self.nets["actor"].forward_train(obs_dict=batch["obs"], goal_dict=batch["goal_obs"])
+ log_prob = dist.log_prob(batch["actions"])
+
+ info["actor/log_prob"] = log_prob.mean()
+
+ # compute advantage estimate
+ q_pred = critic_info["vf/q_pred"]
+ v_pred = critic_info["vf/v_pred"]
+ adv = q_pred - v_pred
+
+ # compute weights
+ weights = self._get_adv_weights(adv)
+
+ # compute advantage weighted actor loss. disable gradients through weights
+ actor_loss = (-log_prob * weights.detach()).mean()
+
+ info["actor/loss"] = actor_loss
+
+ # log adv-related values
+ info["adv/adv"] = adv
+ info["adv/adv_weight"] = weights
+
+ # Return stats
+ return actor_loss, info
+
+ def _update_actor(self, actor_loss):
+ """
+ Helper function for updating actor network. Called by @train_on_batch
+
+ Args:
+ actor_loss (torch.Tensor): actor loss
+ """
+
+ TorchUtils.backprop_for_loss(
+ net=self.nets["actor"],
+ optim=self.optimizers["actor"],
+ loss=actor_loss,
+ max_grad_norm=self.algo_config.actor.max_gradient_norm,
+ )
+
+ def _get_adv_weights(self, adv):
+ """
+ Helper function for computing advantage weights. Called by @_compute_actor_loss
+
+ Args:
+ adv (torch.Tensor): raw advantage estimates
+
+ Returns:
+ weights (torch.Tensor): weights computed based on advantage estimates,
+ in shape (B,) where B is batch size
+ """
+
+ # clip raw advantage values
+ if self.algo_config.adv.clip_adv_value is not None:
+ adv = adv.clamp(max=self.algo_config.adv.clip_adv_value)
+
+ # compute weights based on advantage values
+ beta = self.algo_config.adv.beta # temprature factor
+ weights = torch.exp(adv / beta)
+
+ # clip final weights
+ if self.algo_config.adv.use_final_clip is True:
+ weights = weights.clamp(-100.0, 100.0)
+
+ # reshape from (B, 1) to (B,)
+ return weights[:, 0]
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ log = OrderedDict()
+
+ log["actor/log_prob"] = info["actor/log_prob"].item()
+ log["actor/loss"] = info["actor/loss"].item()
+
+ log["critic/critic1_pred"] = info["critic/critic1_pred"].item()
+ log["critic/critic1_loss"] = info["critic/critic1_loss"].item()
+
+ log["vf/v_loss"] = info["vf/v_loss"].item()
+
+ self._log_data_attributes(log, info, "vf/q_pred")
+ self._log_data_attributes(log, info, "vf/v_pred")
+ self._log_data_attributes(log, info, "adv/adv")
+ self._log_data_attributes(log, info, "adv/adv_weight")
+
+ return log
+
+ def _log_data_attributes(self, log, info, key):
+ """
+ Helper function for logging statistics. Moodifies log in-place
+
+ Args:
+ log (dict): existing log dictionary
+ log (dict): existing dictionary of tensors containing raw stats
+ key (str): key to log
+ """
+ log[key + "/max"] = info[key].max().item()
+ log[key + "/min"] = info[key].min().item()
+ log[key + "/mean"] = info[key].mean().item()
+ log[key + "/std"] = info[key].std().item()
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+
+ # LR scheduling updates
+ for lr_sc in self.lr_schedulers["critic"]:
+ if lr_sc is not None:
+ lr_sc.step()
+
+ if self.lr_schedulers["vf"] is not None:
+ self.lr_schedulers["vf"].step()
+
+ if self.lr_schedulers["actor"] is not None:
+ self.lr_schedulers["actor"].step()
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["actor"](obs_dict=obs_dict, goal_dict=goal_dict)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/iris.py b/phantom/submodules/phantom-robomimic/robomimic/algo/iris.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b441470c796749f92682ecf2b38a48e0bb3ada5
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/iris.py
@@ -0,0 +1,183 @@
+"""
+Implementation of IRIS (https://arxiv.org/abs/1911.05321).
+"""
+import numpy as np
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+from robomimic.config.config import Config
+from robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HBC, ValuePlanner, ValueAlgo, GL_VAE
+
+
+@register_algo_factory_func("iris")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the IRIS algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ pol_cls, _ = algo_name_to_factory_func("bc")(algo_config.actor)
+ plan_cls, _ = algo_name_to_factory_func("gl")(algo_config.value_planner.planner)
+ value_cls, _ = algo_name_to_factory_func("bcq")(algo_config.value_planner.value)
+ return IRIS, dict(policy_algo_class=pol_cls, planner_algo_class=plan_cls, value_algo_class=value_cls)
+
+
+class IRIS(HBC, ValueAlgo):
+ """
+ Implementation of IRIS (https://arxiv.org/abs/1911.05321).
+ """
+ def __init__(
+ self,
+ planner_algo_class,
+ value_algo_class,
+ policy_algo_class,
+ algo_config,
+ obs_config,
+ global_config,
+ obs_key_shapes,
+ ac_dim,
+ device,
+ ):
+ """
+ Args:
+ planner_algo_class (Algo class): algo class for the planner
+
+ policy_algo_class (Algo class): algo class for the policy
+
+ algo_config (Config object): instance of Config corresponding to the algo section
+ of the config
+
+ obs_config (Config object): instance of Config corresponding to the observation
+ section of the config
+
+ global_config (Config object): global training config
+
+ obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes
+
+ ac_dim (int): action dimension
+
+ device: torch device
+ """
+ self.algo_config = algo_config
+ self.obs_config = obs_config
+ self.global_config = global_config
+
+ self.ac_dim = ac_dim
+ self.device = device
+
+ self._subgoal_step_count = 0 # current step count for deciding when to update subgoal
+ self._current_subgoal = None # latest subgoal
+ self._subgoal_update_interval = self.algo_config.subgoal_update_interval # subgoal update frequency
+ self._subgoal_horizon = self.algo_config.value_planner.planner.subgoal_horizon
+ self._actor_horizon = self.algo_config.actor.rnn.horizon
+
+ self._algo_mode = self.algo_config.mode
+ assert self._algo_mode in ["separate", "cascade"]
+
+ self.planner = ValuePlanner(
+ planner_algo_class=planner_algo_class,
+ value_algo_class=value_algo_class,
+ algo_config=algo_config.value_planner,
+ obs_config=obs_config.value_planner,
+ global_config=global_config,
+ obs_key_shapes=obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device
+ )
+
+ self.actor_goal_shapes = self.planner.subgoal_shapes
+ assert not algo_config.latent_subgoal.enabled, "IRIS does not support latent subgoals"
+
+ # only for the actor: override goal modalities and shapes to match the subgoal set by the planner
+ actor_obs_key_shapes = deepcopy(obs_key_shapes)
+ # make sure we are not modifying existing observation key shapes
+ for k in self.actor_goal_shapes:
+ if k in actor_obs_key_shapes:
+ assert actor_obs_key_shapes[k] == self.actor_goal_shapes[k]
+ actor_obs_key_shapes.update(self.actor_goal_shapes)
+
+ goal_modalities = {obs_modality: [] for obs_modality in ObsUtils.OBS_MODALITY_CLASSES.keys()}
+ for k in self.actor_goal_shapes.keys():
+ goal_modalities[ObsUtils.OBS_KEYS_TO_MODALITIES[k]].append(k)
+
+ actor_obs_config = deepcopy(obs_config.actor)
+ with actor_obs_config.unlocked():
+ actor_obs_config["goal"] = Config(**goal_modalities)
+
+ self.actor = policy_algo_class(
+ algo_config=algo_config.actor,
+ obs_config=actor_obs_config,
+ global_config=global_config,
+ obs_key_shapes=actor_obs_key_shapes,
+ ac_dim=ac_dim,
+ device=device
+ )
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ input_batch["planner"] = self.planner.process_batch_for_training(batch)
+ input_batch["actor"] = self.actor.process_batch_for_training(batch)
+
+ if self.algo_config.actor_use_random_subgoals:
+ # optionally use randomly sampled step between [1, seq_length] as policy goal
+ policy_subgoal_indices = torch.randint(
+ low=0, high=self.global_config.train.seq_length, size=(batch["actions"].shape[0],))
+ goal_obs = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
+ goal_obs = TensorUtils.to_float(TensorUtils.to_device(goal_obs, self.device))
+ input_batch["actor"]["goal_obs"] = goal_obs
+ else:
+ # otherwise, use planner subgoal target as goal for the policy
+ input_batch["actor"]["goal_obs"] = input_batch["planner"]["planner"]["target_subgoals"]
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def get_state_value(self, obs_dict, goal_dict=None):
+ """
+ Get state value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ return self.planner.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)
+
+ def get_state_action_value(self, obs_dict, actions, goal_dict=None):
+ """
+ Get state-action value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ actions (torch.Tensor): action
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ return self.planner.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/algo/td3_bc.py b/phantom/submodules/phantom-robomimic/robomimic/algo/td3_bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e324c54a1614c1c01b3efdb1def9c8f6f11b2c70
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/algo/td3_bc.py
@@ -0,0 +1,567 @@
+"""
+Implementation of TD3-BC.
+Based on https://github.com/sfujim/TD3_BC
+(Paper - https://arxiv.org/abs/1812.02900).
+
+Note that several parts are exactly the same as the BCQ implementation,
+such as @_create_critics, @process_batch_for_training, and
+@_train_critic_on_batch. They are replicated here (instead of subclassing
+from the BCQ algo class) to be explicit and have implementation details
+self-contained in this file.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import robomimic.models.obs_nets as ObsNets
+import robomimic.models.policy_nets as PolicyNets
+import robomimic.models.value_nets as ValueNets
+import robomimic.models.vae_nets as VAENets
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.loss_utils as LossUtils
+
+from robomimic.algo import register_algo_factory_func, PolicyAlgo, ValueAlgo
+
+
+@register_algo_factory_func("td3_bc")
+def algo_config_to_class(algo_config):
+ """
+ Maps algo config to the TD3_BC algo class to instantiate, along with additional algo kwargs.
+
+ Args:
+ algo_config (Config instance): algo config
+
+ Returns:
+ algo_class: subclass of Algo
+ algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
+ """
+ # only one variant of TD3_BC for now
+ return TD3_BC, {}
+
+
+class TD3_BC(PolicyAlgo, ValueAlgo):
+ """
+ Default TD3_BC training, based on https://arxiv.org/abs/2106.06860 and
+ https://github.com/sfujim/TD3_BC.
+ """
+ def __init__(self, **kwargs):
+ PolicyAlgo.__init__(self, **kwargs)
+
+ # save the discount factor - it may be overriden later
+ self.set_discount(self.algo_config.discount)
+
+ # initialize actor update counter. This is used to train the actor at a lower freq than critic
+ self.actor_update_counter = 0
+
+ def _create_networks(self):
+ """
+ Creates networks and places them into @self.nets.
+ """
+ self.nets = nn.ModuleDict()
+
+ self._create_critics()
+ self._create_actor()
+
+ # sync target networks at beginning of training
+ with torch.no_grad():
+ for critic_ind in range(len(self.nets["critic"])):
+ TorchUtils.hard_update(
+ source=self.nets["critic"][critic_ind],
+ target=self.nets["critic_target"][critic_ind],
+ )
+
+ TorchUtils.hard_update(
+ source=self.nets["actor"],
+ target=self.nets["actor_target"],
+ )
+
+ self.nets = self.nets.float().to(self.device)
+
+ def _create_critics(self):
+ """
+ Called in @_create_networks to make critic networks.
+
+ Exactly the same as BCQ.
+ """
+ critic_class = ValueNets.ActionValueNetwork
+ critic_args = dict(
+ obs_shapes=self.obs_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.critic.layer_dims,
+ value_bounds=self.algo_config.critic.value_bounds,
+ goal_shapes=self.goal_shapes,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ # Q network ensemble and target ensemble
+ self.nets["critic"] = nn.ModuleList()
+ self.nets["critic_target"] = nn.ModuleList()
+ for _ in range(self.algo_config.critic.ensemble.n):
+ critic = critic_class(**critic_args)
+ self.nets["critic"].append(critic)
+
+ critic_target = critic_class(**critic_args)
+ self.nets["critic_target"].append(critic_target)
+
+ def _create_actor(self):
+ """
+ Called in @_create_networks to make actor network.
+ """
+ actor_class = PolicyNets.ActorNetwork
+ actor_args = dict(
+ obs_shapes=self.obs_shapes,
+ goal_shapes=self.goal_shapes,
+ ac_dim=self.ac_dim,
+ mlp_layer_dims=self.algo_config.actor.layer_dims,
+ encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
+ )
+
+ self.nets["actor"] = actor_class(**actor_args)
+ self.nets["actor_target"] = actor_class(**actor_args)
+
+ def _check_epoch(self, net_name, epoch):
+ """
+ Helper function to check whether backprop should happen this epoch.
+
+ Args:
+ net_name (str): name of network in @self.nets and @self.optim_params
+ epoch (int): epoch number
+ """
+ epoch_start_check = (self.optim_params[net_name]["start_epoch"] == -1) or (epoch >= self.optim_params[net_name]["start_epoch"])
+ epoch_end_check = (self.optim_params[net_name]["end_epoch"] == -1) or (epoch < self.optim_params[net_name]["end_epoch"])
+ return (epoch_start_check and epoch_end_check)
+
+ def set_discount(self, discount):
+ """
+ Useful function to modify discount factor if necessary (e.g. for n-step returns).
+ """
+ self.discount = discount
+
+ def process_batch_for_training(self, batch):
+ """
+ Processes input batch from a data loader to filter out
+ relevant information and prepare the batch for training.
+
+ Exactly the same as BCQ.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader
+
+ Returns:
+ input_batch (dict): processed and filtered batch that
+ will be used for training
+ """
+ input_batch = dict()
+
+ # n-step returns (default is 1)
+ n_step = self.algo_config.n_step
+ assert batch["actions"].shape[1] >= n_step
+
+ # remove temporal batches for all
+ input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
+ input_batch["next_obs"] = {k: batch["next_obs"][k][:, n_step - 1, :] for k in batch["next_obs"]}
+ input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
+ input_batch["actions"] = batch["actions"][:, 0, :]
+
+ # note: ensure scalar signals (rewards, done) retain last dimension of 1 to be compatible with model outputs
+
+ # single timestep reward is discounted sum of intermediate rewards in sequence
+ reward_seq = batch["rewards"][:, :n_step]
+ discounts = torch.pow(self.algo_config.discount, torch.arange(n_step).float()).unsqueeze(0)
+ input_batch["rewards"] = (reward_seq * discounts).sum(dim=1).unsqueeze(1)
+
+ # discount rate will be gamma^N for computing n-step returns
+ new_discount = (self.algo_config.discount ** n_step)
+ self.set_discount(new_discount)
+
+ # consider this n-step seqeunce done if any intermediate dones are present
+ done_seq = batch["dones"][:, :n_step]
+ input_batch["dones"] = (done_seq.sum(dim=1) > 0).float().unsqueeze(1)
+
+ if self.algo_config.infinite_horizon:
+ # scale terminal rewards by 1 / (1 - gamma) for infinite horizon MDPs
+ done_inds = input_batch["dones"].round().long().nonzero(as_tuple=False)[:, 0]
+ if done_inds.shape[0] > 0:
+ input_batch["rewards"][done_inds] = input_batch["rewards"][done_inds] * (1. / (1. - self.discount))
+
+ # we move to device first before float conversion because image observation modalities will be uint8 -
+ # this minimizes the amount of data transferred to GPU
+ return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
+
+ def _train_critic_on_batch(self, batch, epoch, no_backprop=False):
+ """
+ A modular helper function that can be overridden in case
+ subclasses would like to modify training behavior for the
+ critics.
+
+ Exactly the same as BCQ (except for removal of @action_sampler_outputs and @critic_outputs)
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ no_backprop (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = OrderedDict()
+
+ # batch variables
+ s_batch = batch["obs"]
+ a_batch = batch["actions"]
+ r_batch = batch["rewards"]
+ ns_batch = batch["next_obs"]
+ goal_s_batch = batch["goal_obs"]
+
+ # 1 if not done, 0 otherwise
+ done_mask_batch = 1. - batch["dones"]
+ info["done_masks"] = done_mask_batch
+
+ # Bellman backup for Q-targets
+ q_targets = self._get_target_values(
+ next_states=ns_batch,
+ goal_states=goal_s_batch,
+ rewards=r_batch,
+ dones=done_mask_batch,
+ )
+ info["critic/q_targets"] = q_targets
+
+ # Train all critics using this set of targets for regression
+ for critic_ind, critic in enumerate(self.nets["critic"]):
+ critic_loss = self._compute_critic_loss(
+ critic=critic,
+ states=s_batch,
+ actions=a_batch,
+ goal_states=goal_s_batch,
+ q_targets=q_targets,
+ )
+ info["critic/critic{}_loss".format(critic_ind + 1)] = critic_loss
+
+ if not no_backprop:
+ critic_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["critic"][critic_ind],
+ optim=self.optimizers["critic"][critic_ind],
+ loss=critic_loss,
+ max_grad_norm=self.algo_config.critic.max_gradient_norm,
+ )
+ info["critic/critic{}_grad_norms".format(critic_ind + 1)] = critic_grad_norms
+
+ return info
+
+ def _train_actor_on_batch(self, batch, epoch, no_backprop=False):
+ """
+ A modular helper function that can be overridden in case
+ subclasses would like to modify training behavior for the
+ actor.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ no_backprop (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ info = OrderedDict()
+
+ # Actor loss (update with mixture of DDPG loss and BC loss)
+ s_batch = batch["obs"]
+ a_batch = batch["actions"]
+ goal_s_batch = batch["goal_obs"]
+
+ # lambda mixture weight is combination of hyperparameter (alpha) and Q-value normalization
+ actor_actions = self.nets["actor"](s_batch, goal_s_batch)
+ Q_values = self.nets["critic"][0](s_batch, actor_actions, goal_s_batch)
+ lam = self.algo_config.alpha / Q_values.abs().mean().detach()
+ actor_loss = -lam * Q_values.mean() + nn.MSELoss()(actor_actions, a_batch)
+ info["actor/loss"] = actor_loss
+
+ if not no_backprop:
+ actor_grad_norms = TorchUtils.backprop_for_loss(
+ net=self.nets["actor"],
+ optim=self.optimizers["actor"],
+ loss=actor_loss,
+ )
+ info["actor/grad_norms"] = actor_grad_norms
+
+ return info
+
+ def _get_target_values(self, next_states, goal_states, rewards, dones):
+ """
+ Helper function to get target values for training Q-function with TD-loss.
+
+ Args:
+ next_states (dict): batch of next observations
+ goal_states (dict): if not None, batch of goal observations
+ rewards (torch.Tensor): batch of rewards - should be shape (B, 1)
+ dones (torch.Tensor): batch of done signals - should be shape (B, 1)
+
+ Returns:
+ q_targets (torch.Tensor): target Q-values to use for TD loss
+ """
+
+ with torch.no_grad():
+ # get next actions via target actor and noise
+ next_target_actions = self.nets["actor_target"](next_states, goal_states)
+ noise = (
+ torch.randn_like(next_target_actions) * self.algo_config.actor.noise_std
+ ).clamp(-self.algo_config.actor.noise_clip, self.algo_config.actor.noise_clip)
+ next_actions = (next_target_actions + noise).clamp(-1.0, 1.0)
+
+ # TD3 trick to combine max and min over all Q-ensemble estimates into single target estimates
+ all_value_targets = self.nets["critic_target"][0](next_states, next_actions, goal_states).reshape(-1, 1)
+ max_value_targets = all_value_targets
+ min_value_targets = all_value_targets
+ for critic_target in self.nets["critic_target"][1:]:
+ all_value_targets = critic_target(next_states, next_actions, goal_states).reshape(-1, 1)
+ max_value_targets = torch.max(max_value_targets, all_value_targets)
+ min_value_targets = torch.min(min_value_targets, all_value_targets)
+ value_targets = self.algo_config.critic.ensemble.weight * min_value_targets + \
+ (1. - self.algo_config.critic.ensemble.weight) * max_value_targets
+ q_targets = rewards + dones * self.discount * value_targets
+
+ return q_targets
+
+ def _compute_critic_loss(self, critic, states, actions, goal_states, q_targets):
+ """
+ Helper function to compute loss between estimated Q-values and target Q-values.
+
+ Nearly the same as BCQ (return type slightly different).
+
+ Args:
+ critic (torch.nn.Module): critic network
+ states (dict): batch of observations
+ actions (torch.Tensor): batch of actions
+ goal_states (dict): if not None, batch of goal observations
+ q_targets (torch.Tensor): batch of target q-values for the TD loss
+
+ Returns:
+ critic_loss (torch.Tensor): critic loss
+ """
+ q_estimated = critic(states, actions, goal_states)
+ if self.algo_config.critic.use_huber:
+ critic_loss = nn.SmoothL1Loss()(q_estimated, q_targets)
+ else:
+ critic_loss = nn.MSELoss()(q_estimated, q_targets)
+ return critic_loss
+
+ def train_on_batch(self, batch, epoch, validate=False):
+ """
+ Training on a single batch of data.
+
+ Args:
+ batch (dict): dictionary with torch.Tensors sampled
+ from a data loader and filtered by @process_batch_for_training
+
+ epoch (int): epoch number - required by some Algos that need
+ to perform staged training and early stopping
+
+ validate (bool): if True, don't perform any learning updates.
+
+ Returns:
+ info (dict): dictionary of relevant inputs, outputs, and losses
+ that might be relevant for logging
+ """
+ with TorchUtils.maybe_no_grad(no_grad=validate):
+ info = PolicyAlgo.train_on_batch(self, batch, epoch, validate=validate)
+
+ # Critic training
+ no_critic_backprop = validate or (not self._check_epoch(net_name="critic", epoch=epoch))
+ with TorchUtils.maybe_no_grad(no_grad=no_critic_backprop):
+ critic_info = self._train_critic_on_batch(
+ batch=batch,
+ epoch=epoch,
+ no_backprop=no_critic_backprop,
+ )
+ info.update(critic_info)
+
+ # update actor and target networks at lower frequency
+ if not no_critic_backprop:
+ # update counter only on critic training gradient steps
+ self.actor_update_counter += 1
+ do_actor_update = (self.actor_update_counter % self.algo_config.actor.update_freq == 0)
+
+ # Actor training
+ no_actor_backprop = validate or (not self._check_epoch(net_name="actor", epoch=epoch))
+ no_actor_backprop = no_actor_backprop or (not do_actor_update)
+ with TorchUtils.maybe_no_grad(no_grad=no_actor_backprop):
+ actor_info = self._train_actor_on_batch(
+ batch=batch,
+ epoch=epoch,
+ no_backprop=no_actor_backprop,
+ )
+ info.update(actor_info)
+
+ if not no_actor_backprop:
+ # to match original implementation, only update target networks on
+ # actor gradient steps
+ with torch.no_grad():
+ # update the target critic networks
+ for critic_ind in range(len(self.nets["critic"])):
+ TorchUtils.soft_update(
+ source=self.nets["critic"][critic_ind],
+ target=self.nets["critic_target"][critic_ind],
+ tau=self.algo_config.target_tau,
+ )
+
+ # update target actor network
+ TorchUtils.soft_update(
+ source=self.nets["actor"],
+ target=self.nets["actor_target"],
+ tau=self.algo_config.target_tau,
+ )
+
+ return info
+
+ def log_info(self, info):
+ """
+ Process info dictionary from @train_on_batch to summarize
+ information to pass to tensorboard for logging.
+
+ Args:
+ info (dict): dictionary of info
+
+ Returns:
+ loss_log (dict): name -> summary statistic
+ """
+ loss_log = OrderedDict()
+
+ # record current optimizer learning rates
+ for k in self.optimizers:
+ keys = [k]
+ optims = [self.optimizers[k]]
+ if k == "critic":
+ # account for critic having one optimizer per ensemble member
+ keys = ["{}{}".format(k, critic_ind) for critic_ind in range(len(self.nets["critic"]))]
+ optims = self.optimizers[k]
+ for kp, optimizer in zip(keys, optims):
+ for i, param_group in enumerate(optimizer.param_groups):
+ loss_log["Optimizer/{}{}_lr".format(kp, i)] = param_group["lr"]
+
+ # extract relevant logs for critic, and actor
+ loss_log["Loss"] = 0.
+ for loss_logger in [self._log_critic_info, self._log_actor_info]:
+ this_log = loss_logger(info)
+ if "Loss" in this_log:
+ # manually merge total loss
+ loss_log["Loss"] += this_log["Loss"]
+ del this_log["Loss"]
+ loss_log.update(this_log)
+
+ return loss_log
+
+ def _log_critic_info(self, info):
+ """
+ Helper function to extract critic-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ if "done_masks" in info:
+ loss_log["Critic/Done_Mask_Percentage"] = 100. * torch.mean(info["done_masks"]).item()
+ if "critic/q_targets" in info:
+ loss_log["Critic/Q_Targets"] = info["critic/q_targets"].mean().item()
+ loss_log["Loss"] = 0.
+ for critic_ind in range(len(self.nets["critic"])):
+ loss_log["Critic/Critic{}_Loss".format(critic_ind + 1)] = info["critic/critic{}_loss".format(critic_ind + 1)].item()
+ if "critic/critic{}_grad_norms".format(critic_ind + 1) in info:
+ loss_log["Critic/Critic{}_Grad_Norms".format(critic_ind + 1)] = info["critic/critic{}_grad_norms".format(critic_ind + 1)]
+ loss_log["Loss"] += loss_log["Critic/Critic{}_Loss".format(critic_ind + 1)]
+ return loss_log
+
+ def _log_actor_info(self, info):
+ """
+ Helper function to extract actor-relevant information for logging.
+ """
+ loss_log = OrderedDict()
+ loss_log["Actor/Loss"] = info["actor/loss"].item()
+ if "actor/grad_norms" in info:
+ loss_log["Actor/Grad_Norms"] = info["actor/grad_norms"]
+ loss_log["Loss"] = loss_log["Actor/Loss"]
+ return loss_log
+
+ def set_train(self):
+ """
+ Prepare networks for evaluation. Update from super class to make sure
+ target networks stay in evaluation mode all the time.
+ """
+ self.nets.train()
+
+ # target networks always in eval
+ for critic_ind in range(len(self.nets["critic_target"])):
+ self.nets["critic_target"][critic_ind].eval()
+
+ self.nets["actor_target"].eval()
+
+ def on_epoch_end(self, epoch):
+ """
+ Called at the end of each epoch.
+ """
+
+ # LR scheduling updates
+ for lr_sc in self.lr_schedulers["critic"]:
+ if lr_sc is not None:
+ lr_sc.step()
+
+ if self.lr_schedulers["actor"] is not None:
+ self.lr_schedulers["actor"].step()
+
+ def get_action(self, obs_dict, goal_dict=None):
+ """
+ Get policy action outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ action (torch.Tensor): action tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["actor"](obs_dict=obs_dict, goal_dict=goal_dict)
+
+ def get_state_value(self, obs_dict, goal_dict=None):
+ """
+ Get state value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ assert not self.nets.training
+
+ actions = self.nets["actor"](obs_dict=obs_dict, goal_dict=goal_dict)
+ return self.nets["critic"][0](obs_dict, actions, goal_dict)
+
+ def get_state_action_value(self, obs_dict, actions, goal_dict=None):
+ """
+ Get state-action value outputs.
+
+ Args:
+ obs_dict (dict): current observation
+ actions (torch.Tensor): action
+ goal_dict (dict): (optional) goal
+
+ Returns:
+ value (torch.Tensor): value tensor
+ """
+ assert not self.nets.training
+
+ return self.nets["critic"][0](obs_dict, actions, goal_dict)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/__init__.py b/phantom/submodules/phantom-robomimic/robomimic/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2cba6d3d89bcf9c73d7de8995dfef86ba9a8a94
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/__init__.py
@@ -0,0 +1,13 @@
+from robomimic.config.config import Config
+from robomimic.config.base_config import config_factory, get_all_registered_configs
+
+# note: these imports are needed to register these classes in the global config registry
+from robomimic.config.bc_config import BCConfig
+from robomimic.config.bcq_config import BCQConfig
+from robomimic.config.cql_config import CQLConfig
+from robomimic.config.iql_config import IQLConfig
+from robomimic.config.gl_config import GLConfig
+from robomimic.config.hbc_config import HBCConfig
+from robomimic.config.iris_config import IRISConfig
+from robomimic.config.td3_bc_config import TD3_BCConfig
+from robomimic.config.diffusion_policy_config import DiffusionPolicyConfig
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/base_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/base_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8321365f446a15161e857598009cc6c69e97a26b
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/base_config.py
@@ -0,0 +1,336 @@
+"""
+The base config class that is used for all algorithm configs in this repository.
+Subclasses get registered into a global dictionary, making it easy to instantiate
+the correct config class given the algorithm name.
+"""
+
+import six # preserve metaclass compatibility between python 2 and 3
+from copy import deepcopy
+
+import robomimic
+from robomimic.config.config import Config
+
+# global dictionary for remembering name - class mappings
+REGISTERED_CONFIGS = {}
+
+
+def get_all_registered_configs():
+ """
+ Give access to dictionary of all registered configs for external use.
+ """
+ return deepcopy(REGISTERED_CONFIGS)
+
+
+def config_factory(algo_name, dic=None):
+ """
+ Creates an instance of a config from the algo name. Optionally pass
+ a dictionary to instantiate the config from the dictionary.
+ """
+ if algo_name not in REGISTERED_CONFIGS:
+ raise Exception("Config for algo name {} not found. Make sure it is a registered config among: {}".format(
+ algo_name, ', '.join(REGISTERED_CONFIGS)))
+ return REGISTERED_CONFIGS[algo_name](dict_to_load=dic)
+
+
+class ConfigMeta(type):
+ """
+ Define a metaclass for constructing a config class.
+ It registers configs into the global registry.
+ """
+ def __new__(meta, name, bases, class_dict):
+ cls = super(ConfigMeta, meta).__new__(meta, name, bases, class_dict)
+ if cls.__name__ != "BaseConfig":
+ REGISTERED_CONFIGS[cls.ALGO_NAME] = cls
+ return cls
+
+
+@six.add_metaclass(ConfigMeta)
+class BaseConfig(Config):
+ def __init__(self, dict_to_load=None):
+ if dict_to_load is not None:
+ super(BaseConfig, self).__init__(dict_to_load)
+ return
+
+ super(BaseConfig, self).__init__()
+
+ # store algo name class property in the config (must be implemented by subclasses)
+ self.algo_name = type(self).ALGO_NAME
+
+ self.experiment_config()
+ self.train_config()
+ self.algo_config()
+ self.observation_config()
+ self.meta_config()
+
+ # After Config init, new keys cannot be added to the config, except under nested
+ # attributes that have called @do_not_lock_keys
+ self.lock_keys()
+
+ @property
+ @classmethod
+ def ALGO_NAME(cls):
+ # must be specified by subclasses
+ raise NotImplementedError
+
+ def experiment_config(self):
+ """
+ This function populates the `config.experiment` attribute of the config,
+ which has several experiment settings such as the name of the training run,
+ whether to do logging, whether to save models (and how often), whether to render
+ videos, and whether to do rollouts (and how often). This class has a default
+ implementation that usually doesn't need to be overriden.
+ """
+
+ self.experiment.name = "test" # name of experiment used to make log files
+ self.experiment.validate = False # whether to do validation or not
+ self.experiment.logging.terminal_output_to_txt = True # whether to log stdout to txt file
+ self.experiment.logging.log_tb = True # enable tensorboard logging
+ self.experiment.logging.log_wandb = False # enable wandb logging
+ self.experiment.logging.wandb_proj_name = "debug" # project name if using wandb
+
+
+ ## save config - if and when to save model checkpoints ##
+ self.experiment.save.enabled = True # whether model saving should be enabled or disabled
+ self.experiment.save.every_n_seconds = None # save model every n seconds (set to None to disable)
+ self.experiment.save.every_n_epochs = 50 # save model every n epochs (set to None to disable)
+ self.experiment.save.epochs = [] # save model on these specific epochs
+ self.experiment.save.on_best_validation = False # save models that achieve best validation score
+ self.experiment.save.on_best_rollout_return = False # save models that achieve best rollout return
+ self.experiment.save.on_best_rollout_success_rate = True # save models that achieve best success rate
+
+ # epoch definitions - if not None, set an epoch to be this many gradient steps, else the full dataset size will be used
+ self.experiment.epoch_every_n_steps = 100 # number of gradient steps in train epoch (None for full dataset pass)
+ self.experiment.validation_epoch_every_n_steps = 10 # number of gradient steps in valid epoch (None for full dataset pass)
+
+ # envs to evaluate model on (assuming rollouts are enabled), to override the metadata stored in dataset
+ self.experiment.env = None # no need to set this (unless you want to override)
+ self.experiment.additional_envs = None # additional environments that should get evaluated
+
+
+ ## rendering config ##
+ self.experiment.render = False # render on-screen or not
+ self.experiment.render_video = True # render evaluation rollouts to videos
+ self.experiment.keep_all_videos = False # save all videos, instead of only saving those for saved model checkpoints
+ self.experiment.video_skip = 5 # render video frame every n environment steps during rollout
+
+
+ ## evaluation rollout config ##
+ self.experiment.rollout.enabled = True # enable evaluation rollouts
+ self.experiment.rollout.n = 50 # number of rollouts per evaluation
+ self.experiment.rollout.horizon = 400 # maximum number of env steps per rollout
+ self.experiment.rollout.rate = 50 # do rollouts every @rate epochs
+ self.experiment.rollout.warmstart = 0 # number of epochs to wait before starting rollouts
+ self.experiment.rollout.terminate_on_success = True # end rollout early after task success
+
+ # for updating the evaluation env meta data
+ self.experiment.env_meta_update_dict = Config()
+ self.experiment.env_meta_update_dict.do_not_lock_keys()
+
+ def train_config(self):
+ """
+ This function populates the `config.train` attribute of the config, which
+ has several settings related to the training process, such as the dataset
+ to use for training, and how the data loader should load the data. This
+ class has a default implementation that usually doesn't need to be overriden.
+ """
+
+ # Path to hdf5 dataset to use for training
+ self.train.data = None
+
+ # Write all results to this directory. A new folder with the timestamp will be created
+ # in this directory, and it will contain three subfolders - "log", "models", and "videos".
+ # The "log" directory will contain tensorboard and stdout txt logs. The "models" directory
+ # will contain saved model checkpoints. The "videos" directory contains evaluation rollout
+ # videos.
+ self.train.output_dir = "../{}_trained_models".format(self.algo_name)
+
+
+ ## dataset loader config ##
+
+ # num workers for loading data - generally set to 0 for low-dim datasets, and 2 for image datasets
+ self.train.num_data_workers = 0
+
+ # One of ["all", "low_dim", or None]. Set to "all" to cache entire hdf5 in memory - this is
+ # by far the fastest for data loading. Set to "low_dim" to cache all non-image data. Set
+ # to None to use no caching - in this case, every batch sample is retrieved via file i/o.
+ # You should almost never set this to None, even for large image datasets.
+ self.train.hdf5_cache_mode = "all"
+
+ # used for parallel data loading
+ self.train.hdf5_use_swmr = True
+
+ # whether to load "next_obs" group from hdf5 - only needed for batch / offline RL algorithms
+ self.train.hdf5_load_next_obs = True
+
+ # if true, normalize observations at train and test time, using the global mean and standard deviation
+ # of each observation in each dimension, computed across the training set. See SequenceDataset.normalize_obs
+ # in utils/dataset.py for more information.
+ self.train.hdf5_normalize_obs = False
+
+ # if provided, use the list of demo keys under the hdf5 group "mask/@hdf5_filter_key" for training, instead
+ # of the full dataset. This provides a convenient way to train on only a subset of the trajectories in a dataset.
+ self.train.hdf5_filter_key = None
+
+ # if provided, use the list of demo keys under the hdf5 group "mask/@hdf5_validation_filter_key" for validation.
+ # Must be provided if @experiment.validate is True.
+ self.train.hdf5_validation_filter_key = None
+
+ # length of experience sequence to fetch from the dataset
+ # and whether to pad the beginning / end of the sequence at boundaries of trajectory in dataset
+ self.train.seq_length = 1
+ self.train.pad_seq_length = True
+ self.train.frame_stack = 1
+ self.train.pad_frame_stack = True
+
+ # keys from hdf5 to load into each batch, besides "obs" and "next_obs". If algorithms
+ # require additional keys from each trajectory in the hdf5, they should be specified here.
+ self.train.dataset_keys = (
+ "actions",
+ "rewards",
+ "dones",
+ )
+
+ self.train.action_keys = ["actions"]
+
+ # specifing each action keys to load and their corresponding normalization/conversion requirement
+ # e.g. for dataset keys "action/eef_pos" and "action/eef_rot"
+ # the desired value of self.train.action_config is:
+ # {
+ # "action/eef_pos": {
+ # "normalization": "min_max",
+ # "rot_conversion: None
+ # },
+ # "action/eef_rot": {
+ # "normalization": None,
+ # "rot_conversion: "axis_angle_to_6d"
+ # }
+ # }
+ # self.train.action_config.actions.normalization = None # "min_max"
+ # self.train.action_config.actions.rot_conversion = None # "axis_angle_to_6d"
+ self.train.action_config = {}
+ # self.train.action_config.do_not_lock_keys()
+
+ # one of [None, "last"] - set to "last" to include goal observations in each batch
+ self.train.goal_mode = None
+
+
+ ## learning config ##
+ self.train.cuda = True # use GPU or not
+ self.train.batch_size = 100 # batch size
+ self.train.num_epochs = 2000 # number of training epochs
+ self.train.seed = 1 # seed for training (for reproducibility)
+
+ self.train.data_format = "robomimic" # either "robomimic" or "r2d2"
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here. This function should be
+ implemented by every subclass.
+ """
+ pass
+
+ def observation_config(self):
+ """
+ This function populates the `config.observation` attribute of the config, and is given
+ to the `Algo` subclass (see `algo/algo.py`) for each algorithm through the `obs_config`
+ argument to the constructor. This portion of the config is used to specify what
+ observation modalities should be used by the networks for training, and how the
+ observation modalities should be encoded by the networks. While this class has a
+ default implementation that usually doesn't need to be overriden, certain algorithm
+ configs may choose to, in order to have seperate configs for different networks
+ in the algorithm.
+ """
+
+ # observation modalities
+ self.observation.modalities.obs.low_dim = [ # specify low-dim observations for agent
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object",
+ ]
+ self.observation.modalities.obs.rgb = [] # specify rgb image observations for agent
+ self.observation.modalities.obs.depth = []
+ self.observation.modalities.obs.scan = []
+ self.observation.modalities.goal.low_dim = [] # specify low-dim goal observations to condition agent on
+ self.observation.modalities.goal.rgb = [] # specify rgb image goal observations to condition agent on
+ self.observation.modalities.goal.depth = []
+ self.observation.modalities.goal.scan = []
+ self.observation.modalities.obs.do_not_lock_keys()
+ self.observation.modalities.goal.do_not_lock_keys()
+
+ # observation encoder architectures (per obs modality)
+ # This applies to all networks that take observation dicts as input
+
+ # =============== Low Dim default encoder (no encoder) ===============
+ self.observation.encoder.low_dim.core_class = None
+ self.observation.encoder.low_dim.core_kwargs = Config() # No kwargs by default
+ self.observation.encoder.low_dim.core_kwargs.do_not_lock_keys()
+
+ # Low Dim: Obs Randomizer settings
+ self.observation.encoder.low_dim.obs_randomizer_class = None
+ self.observation.encoder.low_dim.obs_randomizer_kwargs = Config() # No kwargs by default
+ self.observation.encoder.low_dim.obs_randomizer_kwargs.do_not_lock_keys()
+
+ # =============== RGB default encoder (ResNet backbone + linear layer output) ===============
+ self.observation.encoder.rgb.core_class = "VisualCore" # Default VisualCore class combines backbone (like ResNet-18) with pooling operation (like spatial softmax)
+ self.observation.encoder.rgb.core_kwargs = Config() # See models/obs_core.py for important kwargs to set and defaults used
+ self.observation.encoder.rgb.core_kwargs.do_not_lock_keys()
+
+ # RGB: Obs Randomizer settings
+ self.observation.encoder.rgb.obs_randomizer_class = None # Can set to 'CropRandomizer' to use crop randomization
+ self.observation.encoder.rgb.obs_randomizer_kwargs = Config() # See models/obs_core.py for important kwargs to set and defaults used
+ self.observation.encoder.rgb.obs_randomizer_kwargs.do_not_lock_keys()
+
+ # Allow for other custom modalities to be specified
+ self.observation.encoder.do_not_lock_keys()
+
+ # =============== Depth default encoder (same as rgb) ===============
+ self.observation.encoder.depth = deepcopy(self.observation.encoder.rgb)
+
+ # =============== Scan default encoder (Conv1d backbone + linear layer output) ===============
+ self.observation.encoder.scan = deepcopy(self.observation.encoder.rgb)
+
+ # Scan: Modify the core class + kwargs, otherwise, is same as rgb encoder
+ self.observation.encoder.scan.core_class = "ScanCore" # Default ScanCore class uses Conv1D to process this modality
+ self.observation.encoder.scan.core_kwargs = Config() # See models/obs_core.py for important kwargs to set and defaults used
+ self.observation.encoder.scan.core_kwargs.do_not_lock_keys()
+
+ def meta_config(self):
+ """
+ This function populates the `config.meta` attribute of the config. This portion of the config
+ is used to specify job information primarily for hyperparameter sweeps.
+ It contains hyperparameter keys and values, which are populated automatically
+ by the hyperparameter config generator (see `utils/hyperparam_utils.py`).
+ These values are read by the wandb logger (see `utils/log_utils.py`) to set job tags.
+ """
+
+ self.meta.hp_base_config_file = None # base config file in hyperparam sweep
+ self.meta.hp_keys = [] # relevant keys (swept) in hyperparam sweep
+ self.meta.hp_values = [] # values corresponding to keys in hyperparam sweep
+
+ @property
+ def use_goals(self):
+ # whether the agent is goal-conditioned
+ return len([obs_key for modality in self.observation.modalities.goal.values() for obs_key in modality]) > 0
+
+ @property
+ def all_obs_keys(self):
+ """
+ This grabs the union of observation keys over all modalities (e.g.: low_dim, rgb, depth, etc.) and over all
+ modality groups (e.g: obs, goal, subgoal, etc...)
+
+ Returns:
+ n-array: all observation keys used for this model
+ """
+ # pool all modalities
+ return sorted(tuple(set([
+ obs_key for group in [
+ self.observation.modalities.obs.values(),
+ self.observation.modalities.goal.values()
+ ]
+ for modality in group
+ for obs_key in modality
+ ])))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/bc_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f701c685e9deb7755729d446ff272ba52a5ccc1
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/bc_config.py
@@ -0,0 +1,106 @@
+"""
+Config for BC algorithm.
+"""
+
+from robomimic.config.base_config import BaseConfig
+
+
+class BCConfig(BaseConfig):
+ ALGO_NAME = "bc"
+
+ def train_config(self):
+ """
+ BC algorithms don't need "next_obs" from hdf5 - so save on storage and compute by disabling it.
+ """
+ super(BCConfig, self).train_config()
+ self.train.hdf5_load_next_obs = False
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # optimization parameters
+ self.algo.optim_params.policy.optimizer_type = "adam"
+ self.algo.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate
+ self.algo.optim_params.policy.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.policy.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.policy.learning_rate.scheduler_type = "multistep" # learning rate scheduler ("multistep", "linear", etc)
+ self.algo.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength
+
+ # loss weights
+ self.algo.loss.l2_weight = 1.0 # L2 loss weight
+ self.algo.loss.l1_weight = 0.0 # L1 loss weight
+ self.algo.loss.cos_weight = 0.0 # cosine loss weight
+
+ # MLP network architecture (layers after observation encoder and RNN, if present)
+ self.algo.actor_layer_dims = (1024, 1024)
+
+ # stochastic Gaussian policy settings
+ self.algo.gaussian.enabled = False # whether to train a Gaussian policy
+ self.algo.gaussian.fixed_std = False # whether to train std output or keep it constant
+ self.algo.gaussian.init_std = 0.1 # initial standard deviation (or constant)
+ self.algo.gaussian.min_std = 0.01 # minimum std output from network
+ self.algo.gaussian.std_activation = "softplus" # activation to use for std output from policy net
+ self.algo.gaussian.low_noise_eval = True # low-std at test-time
+
+ # stochastic GMM policy settings
+ self.algo.gmm.enabled = False # whether to train a GMM policy
+ self.algo.gmm.num_modes = 5 # number of GMM modes
+ self.algo.gmm.min_std = 0.0001 # minimum std output from network
+ self.algo.gmm.std_activation = "softplus" # activation to use for std output from policy net
+ self.algo.gmm.low_noise_eval = True # low-std at test-time
+
+ # stochastic VAE policy settings
+ self.algo.vae.enabled = False # whether to train a VAE policy
+ self.algo.vae.latent_dim = 14 # VAE latent dimnsion - set to twice the dimensionality of action space
+ self.algo.vae.latent_clip = None # clip latent space when decoding (set to None to disable)
+ self.algo.vae.kl_weight = 1. # beta-VAE weight to scale KL loss relative to reconstruction loss in ELBO
+
+ # VAE decoder settings
+ self.algo.vae.decoder.is_conditioned = True # whether decoder should condition on observation
+ self.algo.vae.decoder.reconstruction_sum_across_elements = False # sum instead of mean for reconstruction loss
+
+ # VAE prior settings
+ self.algo.vae.prior.learn = False # learn Gaussian / GMM prior instead of N(0, 1)
+ self.algo.vae.prior.is_conditioned = False # whether to condition prior on observations
+ self.algo.vae.prior.use_gmm = False # whether to use GMM prior
+ self.algo.vae.prior.gmm_num_modes = 10 # number of GMM modes
+ self.algo.vae.prior.gmm_learn_weights = False # whether to learn GMM weights
+ self.algo.vae.prior.use_categorical = False # whether to use categorical prior
+ self.algo.vae.prior.categorical_dim = 10 # the number of categorical classes for each latent dimension
+ self.algo.vae.prior.categorical_gumbel_softmax_hard = False # use hard selection in forward pass
+ self.algo.vae.prior.categorical_init_temp = 1.0 # initial gumbel-softmax temp
+ self.algo.vae.prior.categorical_temp_anneal_step = 0.001 # linear temp annealing rate
+ self.algo.vae.prior.categorical_min_temp = 0.3 # lowest gumbel-softmax temp
+
+ self.algo.vae.encoder_layer_dims = (300, 400) # encoder MLP layer dimensions
+ self.algo.vae.decoder_layer_dims = (300, 400) # decoder MLP layer dimensions
+ self.algo.vae.prior_layer_dims = (300, 400) # prior MLP layer dimensions (if learning conditioned prior)
+
+ # RNN policy settings
+ self.algo.rnn.enabled = False # whether to train RNN policy
+ self.algo.rnn.horizon = 10 # unroll length for RNN - should usually match train.seq_length
+ self.algo.rnn.hidden_dim = 400 # hidden dimension size
+ self.algo.rnn.rnn_type = "LSTM" # rnn type - one of "LSTM" or "GRU"
+ self.algo.rnn.num_layers = 2 # number of RNN layers that are stacked
+ self.algo.rnn.open_loop = False # if True, action predictions are only based on a single observation (not sequence)
+ self.algo.rnn.kwargs.bidirectional = False # rnn kwargs
+ self.algo.rnn.kwargs.do_not_lock_keys()
+
+ # Transformer policy settings
+ self.algo.transformer.enabled = False # whether to train transformer policy
+ self.algo.transformer.context_length = 10 # length of (s, a) seqeunces to feed to transformer - should usually match train.frame_stack
+ self.algo.transformer.embed_dim = 512 # dimension for embeddings used by transformer
+ self.algo.transformer.num_layers = 6 # number of transformer blocks to stack
+ self.algo.transformer.num_heads = 8 # number of attention heads for each transformer block (should divide embed_dim evenly)
+ self.algo.transformer.emb_dropout = 0.1 # dropout probability for embedding inputs in transformer
+ self.algo.transformer.attn_dropout = 0.1 # dropout probability for attention outputs for each transformer block
+ self.algo.transformer.block_output_dropout = 0.1 # dropout probability for final outputs for each transformer block
+ self.algo.transformer.sinusoidal_embedding = False # if True, use standard positional encodings (sin/cos)
+ self.algo.transformer.activation = "gelu" # activation function for MLP in Transformer Block
+ self.algo.transformer.supervise_all_steps = False # if true, supervise all intermediate actions, otherwise only final one
+ self.algo.transformer.nn_parameter_for_timesteps = True # if true, use nn.Parameter otherwise use nn.Embedding
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/bcq_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/bcq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28f5ba5668aa2e5e6d9ca2187953a4e05b56a7d
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/bcq_config.py
@@ -0,0 +1,83 @@
+"""
+Config for BCQ algorithm.
+"""
+
+from robomimic.config.base_config import BaseConfig
+from robomimic.config.bc_config import BCConfig
+
+
+class BCQConfig(BaseConfig):
+ ALGO_NAME = "bcq"
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # optimization parameters
+ self.algo.optim_params.critic.learning_rate.initial = 1e-3 # critic learning rate
+ self.algo.optim_params.critic.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.critic.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.critic.regularization.L2 = 0.00 # L2 regularization strength
+ self.algo.optim_params.critic.start_epoch = -1 # number of epochs before starting critic training (-1 means start right away)
+ self.algo.optim_params.critic.end_epoch = -1 # number of epochs before ending critic training (-1 means start right away)
+
+ self.algo.optim_params.action_sampler.learning_rate.initial = 1e-3 # action sampler learning rate
+ self.algo.optim_params.action_sampler.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.action_sampler.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.action_sampler.regularization.L2 = 0.00 # L2 regularization strength
+ self.algo.optim_params.action_sampler.start_epoch = -1 # number of epochs before starting action sampler training (-1 means start right away)
+ self.algo.optim_params.action_sampler.end_epoch = -1 # number of epochs before ending action sampler training (-1 means start right away)
+
+ self.algo.optim_params.actor.learning_rate.initial = 1e-3 # actor learning rate
+ self.algo.optim_params.actor.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.actor.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.actor.regularization.L2 = 0.00 # L2 regularization strength
+ self.algo.optim_params.actor.start_epoch = -1 # number of epochs before starting actor training (-1 means start right away)
+ self.algo.optim_params.actor.end_epoch = -1 # number of epochs before ending actor training (-1 means start right away)
+
+ # target network related parameters
+ self.algo.discount = 0.99 # discount factor to use
+ self.algo.n_step = 1 # for using n-step returns in TD-updates
+ self.algo.target_tau = 0.005 # update rate for target networks
+ self.algo.infinite_horizon = False # if True, scale terminal rewards by 1 / (1 - discount) to treat as infinite horizon
+
+ # ================== Critic Network Config ===================
+ self.algo.critic.use_huber = False # Huber Loss instead of L2 for critic
+ self.algo.critic.max_gradient_norm = None # L2 gradient clipping for critic (None to use no clipping)
+ self.algo.critic.value_bounds = None # optional 2-tuple to ensure lower and upper bound on value estimates
+ self.algo.critic.num_action_samples = 10 # number of actions to sample per training batch to get target critic value
+ self.algo.critic.num_action_samples_rollout = 100 # number of actions to sample per environment step
+
+ # critic ensemble parameters (TD3 trick)
+ self.algo.critic.ensemble.n = 2 # number of Q networks in the ensemble
+ self.algo.critic.ensemble.weight = 0.75 # weighting for mixing min and max for target Q value
+
+ # distributional critic
+ self.algo.critic.distributional.enabled = False # train distributional critic (C51)
+ self.algo.critic.distributional.num_atoms = 51 # number of values in categorical distribution
+
+ self.algo.critic.layer_dims = (300, 400) # size of critic MLP
+
+ # ================== Action Sampler Config ===================
+ self.algo.action_sampler = BCConfig().algo
+ # use VAE by default
+ self.algo.action_sampler.vae.enabled = True
+ # remove unused parts of BCConfig algo config
+ del self.algo.action_sampler.optim_params # since action sampler optim params specified at top-level
+ del self.algo.action_sampler.loss
+ del self.algo.action_sampler.gaussian
+ del self.algo.action_sampler.rnn
+ del self.algo.action_sampler.transformer
+
+ # Number of epochs before freezing encoder (-1 for no freezing). Only applies to cVAE-based action samplers.
+ with self.algo.action_sampler.unlocked():
+ self.algo.action_sampler.freeze_encoder_epoch = -1
+
+ # ================== Actor Network Config ===================
+ self.algo.actor.enabled = False # whether to use the actor perturbation network
+ self.algo.actor.perturbation_scale = 0.05 # size of learned action perturbations
+ self.algo.actor.layer_dims = (300, 400) # size of actor MLP
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/config.py b/phantom/submodules/phantom-robomimic/robomimic/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..74da6535b385f91aa5c34e20af731ba2e3d06ecb
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/config.py
@@ -0,0 +1,322 @@
+"""
+Basic config class - provides a convenient way to work with nested
+dictionaries (by exposing keys as attributes) and to save / load from jsons.
+
+Based on addict: https://github.com/mewwts/addict
+"""
+
+import json
+import copy
+import contextlib
+from copy import deepcopy
+
+
+class Config(dict):
+
+ def __init__(__self, *args, **kwargs):
+ object.__setattr__(__self, '__key_locked', False) # disallow adding new keys
+ object.__setattr__(__self, '__all_locked', False) # disallow both key and value update
+ object.__setattr__(__self, '__do_not_lock_keys', False) # cannot be key-locked
+ object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
+ object.__setattr__(__self, '__key', kwargs.pop('__key', None))
+ for arg in args:
+ if not arg:
+ continue
+ elif isinstance(arg, dict):
+ for key, val in arg.items():
+ __self[key] = __self._hook(val)
+ elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
+ __self[arg[0]] = __self._hook(arg[1])
+ else:
+ for key, val in iter(arg):
+ __self[key] = __self._hook(val)
+
+ for key, val in kwargs.items():
+ __self[key] = __self._hook(val)
+
+ def lock(self):
+ """
+ Lock the config. Afterwards, new keys cannot be added to the
+ config, and the values of existing keys cannot be modified.
+ """
+ object.__setattr__(self, '__all_locked', True)
+ if self.key_lockable:
+ object.__setattr__(self, '__key_locked', True)
+
+ for k in self:
+ if isinstance(self[k], Config):
+ self[k].lock()
+
+ def unlock(self):
+ """
+ Unlock the config. Afterwards, new keys can be added to the
+ config, and the values of existing keys can be modified.
+ """
+ object.__setattr__(self, '__all_locked', False)
+ object.__setattr__(self, '__key_locked', False)
+
+ for k in self:
+ if isinstance(self[k], Config):
+ self[k].unlock()
+
+ def _get_lock_state_recursive(self):
+ """
+ Internal helper function to get the lock state of all sub-configs recursively.
+ """
+ lock_state = {"__all_locked": self.is_locked, "__key_locked": self.is_key_locked}
+ for k in self:
+ if isinstance(self[k], Config):
+ assert k not in ["__all_locked", "__key_locked"]
+ lock_state[k] = self[k]._get_lock_state_recursive()
+ return lock_state
+
+ def _set_lock_state_recursive(self, lock_state):
+ """
+ Internal helper function to set the lock state of all sub-configs recursively.
+ """
+ lock_state = deepcopy(lock_state)
+ object.__setattr__(self, '__all_locked', lock_state.pop("__all_locked"))
+ object.__setattr__(self, '__key_locked', lock_state.pop("__key_locked"))
+ for k in lock_state:
+ if isinstance(self[k], Config):
+ self[k]._set_lock_state_recursive(lock_state[k])
+
+ def _get_lock_state(self):
+ """
+ Retrieves the lock state of this config.
+
+ Returns:
+ lock_state (dict): a dictionary with an "all_locked" key that is True
+ if both key and value updates are locked and False otherwise, and
+ a "key_locked" key that is True if only key updates are locked (value
+ updates still allowed) and False otherwise
+ """
+ return {
+ "all_locked": self.is_locked,
+ "key_locked": self.is_key_locked
+ }
+
+ def _set_lock_state(self, lock_state):
+ """
+ Sets the lock state for this config.
+
+ Args:
+ lock_state (dict): a dictionary with an "all_locked" key that is True
+ if both key and value updates should be locked and False otherwise, and
+ a "key_locked" key that is True if only key updates should be locked (value
+ updates still allowed) and False otherwise
+ """
+ if lock_state["all_locked"]:
+ self.lock()
+ if lock_state["key_locked"]:
+ self.lock_keys()
+
+ @contextlib.contextmanager
+ def unlocked(self):
+ """
+ A context scope for modifying a Config object. Within the scope,
+ both keys and values can be updated. Upon leaving the scope,
+ the initial level of locking is restored.
+ """
+ lock_state = self._get_lock_state()
+ self.unlock()
+ yield
+ self._set_lock_state(lock_state)
+
+ @contextlib.contextmanager
+ def values_unlocked(self):
+ """
+ A context scope for modifying a Config object. Within the scope,
+ only values can be updated (new keys cannot be created). Upon
+ leaving the scope, the initial level of locking is restored.
+ """
+ lock_state = self._get_lock_state()
+ self.unlock()
+ self.lock_keys()
+ yield
+ self._set_lock_state(lock_state)
+
+ def lock_keys(self):
+ """
+ Lock this config so that new keys cannot be added.
+ """
+ if not self.key_lockable:
+ return
+ object.__setattr__(self, '__key_locked', True)
+ for k in self:
+ if isinstance(self[k], Config):
+ self[k].lock_keys()
+
+ def unlock_keys(self):
+ """
+ Unlock this config so that new keys can be added.
+ """
+ object.__setattr__(self, '__key_locked', False)
+ for k in self:
+ if isinstance(self[k], Config):
+ self[k].unlock_keys()
+
+ @property
+ def is_locked(self):
+ """
+ Returns True if the config is locked (no key or value updates allowed).
+ """
+ return object.__getattribute__(self, '__all_locked')
+
+ @property
+ def is_key_locked(self):
+ """
+ Returns True if the config is key-locked (no key updates allowed).
+ """
+ return object.__getattribute__(self, '__key_locked')
+
+ def do_not_lock_keys(self):
+ """
+ Calling this function on this config indicates that key updates should be
+ allowed even when this config is key-locked (but not when it is completely
+ locked). This is convenient for attributes that contain kwargs, where there
+ might be a variable type and number of arguments contained in the sub-config.
+ """
+ object.__setattr__(self, '__do_not_lock_keys', True)
+
+ @property
+ def key_lockable(self):
+ """
+ Returns true if this config is key-lockable (new keys cannot be inserted in a
+ key-locked lock level).
+ """
+ return not object.__getattribute__(self, '__do_not_lock_keys')
+
+ def __setattr__(self, name, value):
+ if self.is_locked:
+ raise RuntimeError("This config has been locked - cannot set attribute '{}' to {}".format(name, value))
+
+ if hasattr(Config, name):
+ raise AttributeError("'Dict' object attribute "
+ "'{0}' is read-only".format(name))
+ elif not hasattr(self, name) and self.is_key_locked:
+ raise RuntimeError("This config is key-locked - cannot add key '{}'".format(name))
+ else:
+ self[name] = value
+
+ def __setitem__(self, name, value):
+ super(Config, self).__setitem__(name, value)
+ p = object.__getattribute__(self, '__parent')
+ key = object.__getattribute__(self, '__key')
+ if p is not None:
+ p[key] = self
+
+ def __add__(self, other):
+ if not self.keys():
+ return other
+ else:
+ self_type = type(self).__name__
+ other_type = type(other).__name__
+ msg = "unsupported operand type(s) for +: '{}' and '{}'"
+ raise TypeError(msg.format(self_type, other_type))
+
+ @classmethod
+ def _hook(cls, item):
+ if isinstance(item, dict):
+ # We return Config instance instead of cls instance to ensure all sub-configs are not a top-level class
+ return Config(item)
+ elif isinstance(item, (list, tuple)):
+ return type(item)(Config._hook(elem) for elem in item)
+ return item
+
+ def __getattr__(self, item):
+ return self.__getitem__(item)
+
+ def __repr__(self):
+ json_string = json.dumps(self.to_dict(), indent=4)
+ return json_string
+
+ def __getitem__(self, name):
+ if name not in self:
+ if object.__getattribute__(self, '__all_locked') or object.__getattribute__(self, '__key_locked'):
+ raise RuntimeError("This config has been locked and '{}' is not in this config".format(name))
+ return Config(__parent=self, __key=name)
+ return super(Config, self).__getitem__(name)
+
+ def __delattr__(self, name):
+ del self[name]
+
+ def to_dict(self):
+ base = {}
+ for key, value in self.items():
+ if isinstance(value, type(self)):
+ base[key] = value.to_dict()
+ elif isinstance(value, (list, tuple)):
+ base[key] = type(value)(
+ item.to_dict() if isinstance(item, type(self)) else
+ item for item in value)
+ else:
+ base[key] = value
+ return base
+
+ def copy(self):
+ return copy.copy(self)
+
+ def deepcopy(self):
+ return copy.deepcopy(self)
+
+ def __deepcopy__(self, memo):
+ other = self.__class__()
+ memo[id(self)] = other
+ for key, value in self.items():
+ other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
+ return other
+
+ def update(self, *args, **kwargs):
+ """
+ Update this config using another config or nested dictionary.
+ """
+ if self.is_locked:
+ raise RuntimeError('Cannot update - this config has been locked')
+ other = {}
+ if args:
+ if len(args) > 1:
+ raise TypeError()
+ other.update(args[0])
+ other.update(kwargs)
+ for k, v in other.items():
+ if self.is_key_locked and k not in self:
+ raise RuntimeError("Cannot update - this config has been key-locked and key '{}' does not exist".format(k))
+ if (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
+ self[k] = v
+ else:
+ self[k].update(v)
+
+ def __getnewargs__(self):
+ return tuple(self.items())
+
+ def __getstate__(self):
+ return self
+
+ def __setstate__(self, state):
+ self.update(state)
+
+ def setdefault(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ self[key] = default
+ return default
+
+ def dump(self, filename=None):
+ """
+ Dumps the config to a json.
+
+ Args:
+ filename (str): if not None, save to json file.
+
+ Returns:
+ json_string (str): json string representation of
+ this config
+ """
+ json_string = json.dumps(self.to_dict(), indent=4)
+ if filename is not None:
+ f = open(filename, "w")
+ f.write(json_string)
+ f.close()
+ return json_string
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/cql_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..26fea048fe49d2d2d03f888eedab6754f37c8dcc
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/cql_config.py
@@ -0,0 +1,82 @@
+"""
+Config for CQL algorithm.
+"""
+
+from robomimic.config.base_config import BaseConfig
+
+
+class CQLConfig(BaseConfig):
+ ALGO_NAME = "cql"
+
+ def train_config(self):
+ """
+ Update from superclass to change default batch size.
+ """
+ super(CQLConfig, self).train_config()
+
+ # increase batch size to 1024 (found to work better for most manipulation experiments)
+ self.train.batch_size = 1024
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # optimization parameters
+ self.algo.optim_params.critic.learning_rate.initial = 1e-3 # critic learning rate
+ self.algo.optim_params.critic.learning_rate.decay_factor = 0.0 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.critic.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.critic.regularization.L2 = 0.00 # L2 regularization strength
+
+ self.algo.optim_params.actor.learning_rate.initial = 3e-4 # actor learning rate
+ self.algo.optim_params.actor.learning_rate.decay_factor = 0.0 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.actor.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.actor.regularization.L2 = 0.00 # L2 regularization strength
+
+ # target network related parameters
+ self.algo.discount = 0.99 # discount factor to use
+ self.algo.n_step = 1 # for using n-step returns in TD-updates
+ self.algo.target_tau = 0.005 # update rate for target networks
+
+ # ================== Actor Network Config ===================
+ self.algo.actor.bc_start_steps = 0 # uses BC policy loss for first n-training steps
+ self.algo.actor.target_entropy = "default" # None is fixed entropy, otherwise is automatically tuned to match target. Can specify "default" as well for default tuning target
+ self.algo.actor.max_gradient_norm = None # L2 gradient clipping for actor
+
+ # Actor network settings
+ self.algo.actor.net.type = "gaussian" # Options are currently only "gaussian" (no support for GMM yet)
+
+ # Actor network settings - shared
+ self.algo.actor.net.common.std_activation = "exp" # Activation to use for std output from policy net
+ self.algo.actor.net.common.use_tanh = True # Whether to use tanh at output of actor network
+ self.algo.actor.net.common.low_noise_eval = True # Whether to use deterministic action sampling at eval stage
+
+ # Actor network settings - gaussian
+ self.algo.actor.net.gaussian.init_last_fc_weight = 0.001 # If set, will override the initialization of the final fc layer to be uniformly sampled limited by this value
+ self.algo.actor.net.gaussian.init_std = 0.3 # Relative scaling factor for std from policy net
+ self.algo.actor.net.gaussian.fixed_std = False # Whether to learn std dev or not
+
+ self.algo.actor.layer_dims = (300, 400) # actor MLP layer dimensions
+
+ # ================== Critic Network Config ===================
+ self.algo.critic.use_huber = False # Huber Loss instead of L2 for critic
+ self.algo.critic.max_gradient_norm = None # L2 gradient clipping for critic (None to use no clipping)
+
+ self.algo.critic.value_bounds = None # optional 2-tuple to ensure lower and upper bound on value estimates
+
+ self.algo.critic.num_action_samples = 1 # number of actions to sample per training batch to get target critic value; use maximum Q value from n random sampled actions when doing TD error backup
+
+ # cql settings for critic
+ self.algo.critic.cql_weight = 1.0 # weighting for cql component of critic loss (only used if target_q_gap is < 0 or None)
+ self.algo.critic.deterministic_backup = True # if not set, subtract weighted logprob of action when doing backup
+ self.algo.critic.min_q_weight = 1.0 # min q weight (scaling factor) to apply
+ self.algo.critic.target_q_gap = 5.0 # if set, sets the diff threshold at which Q-values will be penalized more (note: this overrides cql weight above!) Use None or a negative value if not set
+ self.algo.critic.num_random_actions = 10 # Number of random actions to sample when calculating CQL loss
+
+ # critic ensemble parameters (TD3 trick)
+ self.algo.critic.ensemble.n = 2 # number of Q networks in the ensemble
+
+ self.algo.critic.layer_dims = (300, 400) # critic MLP layer dimensions
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/default_templates/bc_transformer.json b/phantom/submodules/phantom-robomimic/robomimic/config/default_templates/bc_transformer.json
new file mode 100644
index 0000000000000000000000000000000000000000..ed59f175b532c1cd61e8c4efefba1d985e8eaa31
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/default_templates/bc_transformer.json
@@ -0,0 +1,171 @@
+{
+ "algo_name": "bc",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../bc_transformer_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": false,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 10,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "policy": {
+ "optimizer_type": "adamw",
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": [100],
+ "scheduler_type": "linear"
+ },
+ "regularization": {
+ "L2": 0.01
+ }
+ }
+ },
+ "loss": {
+ "l2_weight": 1.0,
+ "l1_weight": 0.0,
+ "cos_weight": 0.0
+ },
+ "actor_layer_dims": [],
+ "gaussian": {
+ "enabled": false
+ },
+ "gmm": {
+ "enabled": true,
+ "num_modes": 5,
+ "min_std": 0.0001,
+ "std_activation": "softplus",
+ "low_noise_eval": true
+ },
+ "vae": {
+ "enabled": false
+ },
+ "rnn": {
+ "enabled": false
+ },
+ "transformer": {
+ "enabled": true,
+ "supervise_all_steps": false,
+ "num_layers": 6,
+ "embed_dim": 512,
+ "num_heads": 8
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {
+ "feature_dimension": 64,
+ "backbone_class": "ResNet18Conv",
+ "backbone_kwargs": {
+ "pretrained": false,
+ "input_coord_conv": false
+ },
+ "pool_class": "SpatialSoftmax",
+ "pool_kwargs": {
+ "num_kp": 32,
+ "learnable_temperature": false,
+ "temperature": 1.0,
+ "noise_std": 0.0
+ }
+ },
+ "obs_randomizer_class": "CropRandomizer",
+ "obs_randomizer_kwargs": {
+ "crop_height": 76,
+ "crop_width": 76,
+ "num_crops": 1,
+ "pos_enc": false
+ }
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ }
+}
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/diffusion_policy_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/diffusion_policy_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8662a107d53c2cdae95454caa521a677326a01d8
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/diffusion_policy_config.py
@@ -0,0 +1,57 @@
+"""
+Config for Diffusion Policy algorithm.
+"""
+
+from robomimic.config.base_config import BaseConfig
+
+class DiffusionPolicyConfig(BaseConfig):
+ ALGO_NAME = "diffusion_policy"
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # optimization parameters
+ self.algo.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate
+ self.algo.optim_params.policy.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.policy.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength
+
+ # horizon parameters
+ self.algo.horizon.observation_horizon = 2
+ self.algo.horizon.action_horizon = 8
+ self.algo.horizon.prediction_horizon = 16
+
+ # UNet parameters
+ self.algo.unet.enabled = True
+ self.algo.unet.diffusion_step_embed_dim = 256
+ self.algo.unet.down_dims = [256,512,1024]
+ self.algo.unet.kernel_size = 5
+ self.algo.unet.n_groups = 8
+
+ # EMA parameters
+ self.algo.ema.enabled = True
+ self.algo.ema.power = 0.75
+
+ # Noise Scheduler
+ ## DDPM
+ self.algo.ddpm.enabled = True
+ self.algo.ddpm.num_train_timesteps = 100
+ self.algo.ddpm.num_inference_timesteps = 100
+ self.algo.ddpm.beta_schedule = 'squaredcos_cap_v2'
+ self.algo.ddpm.clip_sample = True
+ self.algo.ddpm.prediction_type = 'epsilon'
+
+ ## DDIM
+ self.algo.ddim.enabled = False
+ self.algo.ddim.num_train_timesteps = 100
+ self.algo.ddim.num_inference_timesteps = 10
+ self.algo.ddim.beta_schedule = 'squaredcos_cap_v2'
+ self.algo.ddim.clip_sample = True
+ self.algo.ddim.set_alpha_to_one = True
+ self.algo.ddim.steps_offset = 0
+ self.algo.ddim.prediction_type = 'epsilon'
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/gl_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/gl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..939103e65dd5f7519fb7be2c9fa1928d5b430bf2
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/gl_config.py
@@ -0,0 +1,89 @@
+"""
+Config for Goal Learning (sub-algorithm used by hierarchical models like HBC and IRIS).
+This class of model predicts (or samples) subgoal observations given a current observation.
+"""
+
+from robomimic.config.base_config import BaseConfig
+
+
+class GLConfig(BaseConfig):
+ ALGO_NAME = "gl"
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # optimization parameters
+ self.algo.optim_params.goal_network.learning_rate.initial = 1e-4 # goal network learning rate
+ self.algo.optim_params.goal_network.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.goal_network.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.goal_network.regularization.L2 = 0.00
+
+ # subgoal definition: observation that is @subgoal_horizon number of timesteps in future from current observation
+ self.algo.subgoal_horizon = 10
+
+ # MLP size for deterministic goal network (unused if VAE is enabled)
+ self.algo.ae.planner_layer_dims = (300, 400)
+
+ # ================== VAE config ==================
+ self.algo.vae.enabled = True # set to true to use VAE network
+ self.algo.vae.latent_dim = 16 # VAE latent dimension
+ self.algo.vae.latent_clip = None # clip latent space when decoding (set to None to disable)
+ self.algo.vae.kl_weight = 1. # beta-VAE weight to scale KL loss relative to reconstruction loss in ELBO
+
+ # VAE decoder settings
+ self.algo.vae.decoder.is_conditioned = True # whether decoder should condition on observation
+ self.algo.vae.decoder.reconstruction_sum_across_elements = False # sum instead of mean for reconstruction loss
+
+ # VAE prior settings
+ self.algo.vae.prior.learn = False # learn Gaussian / GMM prior instead of N(0, 1)
+ self.algo.vae.prior.is_conditioned = False # whether to condition prior on observations
+ self.algo.vae.prior.use_gmm = False # whether to use GMM prior
+ self.algo.vae.prior.gmm_num_modes = 10 # number of GMM modes
+ self.algo.vae.prior.gmm_learn_weights = False # whether to learn GMM weights
+ self.algo.vae.prior.use_categorical = False # whether to use categorical prior
+ self.algo.vae.prior.categorical_dim = 10 # the number of categorical classes for each latent dimension
+ self.algo.vae.prior.categorical_gumbel_softmax_hard = False # use hard selection in forward pass
+ self.algo.vae.prior.categorical_init_temp = 1.0 # initial gumbel-softmax temp
+ self.algo.vae.prior.categorical_temp_anneal_step = 0.001 # linear temp annealing rate
+ self.algo.vae.prior.categorical_min_temp = 0.3 # lowest gumbel-softmax temp
+
+ self.algo.vae.encoder_layer_dims = (300, 400) # encoder MLP layer dimensions
+ self.algo.vae.decoder_layer_dims = (300, 400) # decoder MLP layer dimensions
+ self.algo.vae.prior_layer_dims = (300, 400) # prior MLP layer dimensions (if learning conditioned prior)
+
+ def observation_config(self):
+ """
+ Update from superclass to specify subgoal modalities.
+ """
+ super(GLConfig, self).observation_config()
+ self.observation.modalities.subgoal.low_dim = [ # specify low-dim subgoal observations for agent to predict
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object",
+ ]
+ self.observation.modalities.subgoal.rgb = [] # specify rgb image subgoal observations for agent to predict
+ self.observation.modalities.subgoal.depth = []
+ self.observation.modalities.subgoal.scan = []
+ self.observation.modalities.subgoal.do_not_lock_keys()
+
+ @property
+ def all_obs_keys(self):
+ """
+ Update from superclass to include subgoals.
+ """
+ # pool all modalities
+ return sorted(tuple(set([
+ obs_key for group in [
+ self.observation.modalities.obs.values(),
+ self.observation.modalities.goal.values(),
+ self.observation.modalities.subgoal.values(),
+ ]
+ for modality in group
+ for obs_key in modality
+ ])))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/hbc_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/hbc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae65c9b85fc168dc65666392fd334810b622ab04
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/hbc_config.py
@@ -0,0 +1,96 @@
+"""
+Config for HBC algorithm.
+"""
+
+from robomimic.config.base_config import BaseConfig
+from robomimic.config.gl_config import GLConfig
+from robomimic.config.bc_config import BCConfig
+
+
+class HBCConfig(BaseConfig):
+ ALGO_NAME = "hbc"
+
+ def train_config(self):
+ """
+ Update from superclass to change default sequence length to load from dataset.
+ """
+ super(HBCConfig, self).train_config()
+ self.train.seq_length = 10 # length of experience sequence to fetch from the buffer
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # One of ["separate", "cascade"]. In "separate" mode (default),
+ # the planner and actor are trained independently and then the planner subgoal predictions are
+ # used to condition the actor at test-time. In "cascade" mode, the actor is trained directly
+ # on planner subgoal predictions. In "actor_only" mode, only the actor is trained, and in
+ # "planner_only" mode, only the planner is trained.
+ self.algo.mode = "separate"
+ self.algo.actor_use_random_subgoals = False # whether to sample subgoal index from [1, subgoal_horizon]
+ self.algo.subgoal_update_interval = 10 # how frequently the subgoal should be updated at test-time
+
+
+ # ================== Latent Subgoal Config ==================
+ self.algo.latent_subgoal.enabled = False # if True, use VAE latent space as subgoals for actor, instead of reconstructions
+
+ # prior correction trick for actor and value training: instead of using encoder for
+ # transforming subgoals to latent subgoals, generate prior samples and choose
+ # the closest one to the encoder output
+ self.algo.latent_subgoal.prior_correction.enabled = False
+ self.algo.latent_subgoal.prior_correction.num_samples = 100
+
+ # ================== Planner Config ==================
+ self.algo.planner = GLConfig().algo # config for goal learning
+ # set subgoal horizon explicitly
+ self.algo.planner.subgoal_horizon = 10
+ # ensure VAE is used
+ self.algo.planner.vae.enabled = True
+
+ # ================== Actor Config ===================
+ self.algo.actor = BCConfig().algo
+ # use RNN
+ self.algo.actor.rnn.enabled = True
+ self.algo.actor.rnn.horizon = 10
+ # remove unused parts of BCConfig algo config
+ del self.algo.actor.gaussian
+ del self.algo.actor.gmm
+ del self.algo.actor.vae
+
+ def observation_config(self):
+ """
+ Update from superclass so that planner and actor each get their own observation config.
+ """
+ self.observation.planner = GLConfig().observation
+ self.observation.actor = BCConfig().observation
+
+ @property
+ def use_goals(self):
+ """
+ Update from superclass - planner goal modalities determine goal-conditioning
+ """
+ return len(
+ self.observation.planner.modalities.goal.low_dim +
+ self.observation.planner.modalities.goal.rgb) > 0
+
+ @property
+ def all_obs_keys(self):
+ """
+ Update from superclass to include modalities from planner and actor.
+ """
+ # pool all modalities
+ return sorted(tuple(set([
+ obs_key for group in [
+ self.observation.planner.modalities.obs.values(),
+ self.observation.planner.modalities.goal.values(),
+ self.observation.planner.modalities.subgoal.values(),
+ self.observation.actor.modalities.obs.values(),
+ self.observation.actor.modalities.goal.values(),
+ ]
+ for modality in group
+ for obs_key in modality
+ ])))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/iql_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/iql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd603d1aa0183639971b16747c5020afa6d04fe3
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/iql_config.py
@@ -0,0 +1,73 @@
+"""
+Config for IQL algorithm.
+"""
+
+from robomimic.config.base_config import BaseConfig
+
+
+class IQLConfig(BaseConfig):
+ ALGO_NAME = "iql"
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+ super(IQLConfig, self).algo_config()
+
+ # optimization parameters
+ self.algo.optim_params.critic.learning_rate.initial = 1e-4 # critic learning rate
+ self.algo.optim_params.critic.learning_rate.decay_factor = 0.0 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.critic.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.critic.regularization.L2 = 0.00 # L2 regularization strength
+
+ self.algo.optim_params.vf.learning_rate.initial = 1e-4 # vf learning rate
+ self.algo.optim_params.vf.learning_rate.decay_factor = 0.0 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.vf.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.vf.regularization.L2 = 0.00 # L2 regularization strength
+
+ self.algo.optim_params.actor.learning_rate.initial = 1e-4 # actor learning rate
+ self.algo.optim_params.actor.learning_rate.decay_factor = 0.0 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.actor.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.actor.regularization.L2 = 0.00 # L2 regularization strength
+
+ # target network related parameters
+ self.algo.discount = 0.99 # discount factor to use
+ self.algo.target_tau = 0.01 # update rate for target networks
+
+ # ================== Actor Network Config ===================
+ # Actor network settings
+ self.algo.actor.net.type = "gaussian" # Options are currently ["gaussian", "gmm"]
+
+ # Actor network settings - shared
+ self.algo.actor.net.common.std_activation = "softplus" # Activation to use for std output from policy net
+ self.algo.actor.net.common.low_noise_eval = True # Whether to use deterministic action sampling at eval stage
+ self.algo.actor.net.common.use_tanh = False # Whether to use tanh at output of actor network
+
+ # Actor network settings - gaussian
+ self.algo.actor.net.gaussian.init_last_fc_weight = 0.001 # If set, will override the initialization of the final fc layer to be uniformly sampled limited by this value
+ self.algo.actor.net.gaussian.init_std = 0.3 # Relative scaling factor for std from policy net
+ self.algo.actor.net.gaussian.fixed_std = False # Whether to learn std dev or not
+
+ self.algo.actor.net.gmm.num_modes = 5 # number of GMM modes
+ self.algo.actor.net.gmm.min_std = 0.0001 # minimum std output from network
+
+ self.algo.actor.layer_dims = (300, 400) # actor MLP layer dimensions
+
+ self.algo.actor.max_gradient_norm = None # L2 gradient clipping for actor
+
+ # ================== Critic Network Config ===================
+ # critic ensemble parameters
+ self.algo.critic.ensemble.n = 2 # number of Q networks in the ensemble
+ self.algo.critic.layer_dims = (300, 400) # critic MLP layer dimensions
+ self.algo.critic.use_huber = False # Huber Loss instead of L2 for critic
+ self.algo.critic.max_gradient_norm = None # L2 gradient clipping for actor
+
+ # ================== Adv Config ==============================
+ self.algo.adv.clip_adv_value = None # whether to clip raw advantage estimates
+ self.algo.adv.beta = 1.0 # temperature for operator
+ self.algo.adv.use_final_clip = True # whether to clip final weight calculations
+
+ self.algo.vf_quantile = 0.9 # quantile factor in quantile regression
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/iris_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/iris_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03328cead61f1a977d76bb4b684613586c2a08c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/iris_config.py
@@ -0,0 +1,99 @@
+"""
+Config for IRIS algorithm.
+"""
+
+from robomimic.config.bcq_config import BCQConfig
+from robomimic.config.gl_config import GLConfig
+from robomimic.config.bc_config import BCConfig
+from robomimic.config.hbc_config import HBCConfig
+
+
+class IRISConfig(HBCConfig):
+ ALGO_NAME = "iris"
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # One of ["separate", "cascade"]. In "separate" mode (default),
+ # the planner and actor are trained independently and then the planner subgoal predictions are
+ # used to condition the actor at test-time. In "cascade" mode, the actor is trained directly
+ # on planner subgoal predictions. In "actor_only" mode, only the actor is trained, and in
+ # "planner_only" mode, only the planner is trained.
+ self.algo.mode = "separate"
+
+ self.algo.actor_use_random_subgoals = False # whether to sample subgoal index from [1, subgoal_horizon]
+ self.algo.subgoal_update_interval = 10 # how frequently the subgoal should be updated at test-time (usually matches train.seq_length)
+
+ # ================== Latent Subgoal Config ==================
+
+ # NOTE: latent subgoals are not supported by IRIS, but superclass expects this config
+ self.algo.latent_subgoal.enabled = False
+ self.algo.latent_subgoal.prior_correction.enabled = False
+ self.algo.latent_subgoal.prior_correction.num_samples = 100
+
+ # ================== Planner Config ==================
+
+ # The ValuePlanner planner component is a Goal Learning VAE model
+ self.algo.value_planner.planner = GLConfig().algo # config for goal learning
+ # set subgoal horizon explicitly
+ self.algo.value_planner.planner.subgoal_horizon = 10
+ # ensure VAE is used
+ self.algo.value_planner.planner.vae.enabled = True
+
+ # The ValuePlanner value component is a BCQ model
+ self.algo.value_planner.value = BCQConfig().algo
+ self.algo.value_planner.value.actor.enabled = False # ensure no BCQ actor
+ # number of subgoal samples to use for value planner
+ self.algo.value_planner.num_samples = 100
+
+ # ================== Actor Config ===================
+ self.algo.actor = BCConfig().algo
+ # use RNN
+ self.algo.actor.rnn.enabled = True
+ self.algo.actor.rnn.horizon = 10
+ # remove unused parts of BCConfig algo config
+ del self.algo.actor.gaussian
+ del self.algo.actor.gmm
+ del self.algo.actor.vae
+
+ def observation_config(self):
+ """
+ Update from superclass so that value planner and actor each get their own obs config.
+ """
+ self.observation.value_planner.planner = GLConfig().observation
+ self.observation.value_planner.value = BCQConfig().observation
+ self.observation.actor = BCConfig().observation
+
+ @property
+ def use_goals(self):
+ """
+ Update from superclass - value planner goal modalities determine goal-conditioning.
+ """
+ return len(
+ self.observation.value_planner.planner.modalities.goal.low_dim +
+ self.observation.value_planner.planner.modalities.goal.rgb) > 0
+
+ @property
+ def all_obs_keys(self):
+ """
+ Update from superclass to include modalities from value planner and actor.
+ """
+ # pool all modalities
+ return sorted(tuple(set([
+ obs_key for group in [
+ self.observation.value_planner.planner.modalities.obs.values(),
+ self.observation.value_planner.planner.modalities.goal.values(),
+ self.observation.value_planner.planner.modalities.subgoal.values(),
+ self.observation.value_planner.value.modalities.obs.values(),
+ self.observation.value_planner.value.modalities.goal.values(),
+ self.observation.actor.modalities.obs.values(),
+ self.observation.actor.modalities.goal.values(),
+ ]
+ for modality in group
+ for obs_key in modality
+ ])))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/config/td3_bc_config.py b/phantom/submodules/phantom-robomimic/robomimic/config/td3_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..036a2591a91b4a4f5da4e2415dd035117e587900
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/config/td3_bc_config.py
@@ -0,0 +1,111 @@
+"""
+Config for TD3_BC.
+"""
+
+from robomimic.config.base_config import BaseConfig
+
+
+class TD3_BCConfig(BaseConfig):
+ ALGO_NAME = "td3_bc"
+
+ def experiment_config(self):
+ """
+ Update from subclass to set paper defaults for gym envs.
+ """
+ super(TD3_BCConfig, self).experiment_config()
+
+ # no validation and no video rendering
+ self.experiment.validate = False
+ self.experiment.render_video = False
+
+ # save 10 checkpoints throughout training
+ self.experiment.save.every_n_epochs = 20
+
+ # save models that achieve best rollout return instead of best success rate
+ self.experiment.save.on_best_rollout_return = True
+ self.experiment.save.on_best_rollout_success_rate = False
+
+ # epoch definition - 5000 gradient steps per epoch, with 200 epochs = 1M gradient steps, and eval every 1 epochs
+ self.experiment.epoch_every_n_steps = 5000
+
+ # evaluate with normal environment rollouts
+ self.experiment.rollout.enabled = True
+ self.experiment.rollout.n = 50 # paper uses 10, but we can afford to do 50
+ self.experiment.rollout.horizon = 1000
+ self.experiment.rollout.rate = 1 # rollout every epoch to match paper
+
+ def train_config(self):
+ """
+ Update from subclass to set paper defaults for gym envs.
+ """
+ super(TD3_BCConfig, self).train_config()
+
+ # update to normalize observations
+ self.train.hdf5_normalize_obs = True
+
+ # increase batch size to 256
+ self.train.batch_size = 256
+
+ # 200 epochs, with each epoch lasting 5000 gradient steps, for 1M total steps
+ self.train.num_epochs = 200
+
+ def algo_config(self):
+ """
+ This function populates the `config.algo` attribute of the config, and is given to the
+ `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
+ argument to the constructor. Any parameter that an algorithm needs to determine its
+ training and test-time behavior should be populated here.
+ """
+
+ # optimization parameters
+ self.algo.optim_params.critic.learning_rate.initial = 3e-4 # critic learning rate
+ self.algo.optim_params.critic.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.critic.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.critic.regularization.L2 = 0.00 # L2 regularization strength
+ self.algo.optim_params.critic.start_epoch = -1 # number of epochs before starting critic training (-1 means start right away)
+ self.algo.optim_params.critic.end_epoch = -1 # number of epochs before ending critic training (-1 means start right away)
+
+ self.algo.optim_params.actor.learning_rate.initial = 3e-4 # actor learning rate
+ self.algo.optim_params.actor.learning_rate.decay_factor = 0.1 # factor to decay LR by (if epoch schedule non-empty)
+ self.algo.optim_params.actor.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
+ self.algo.optim_params.actor.regularization.L2 = 0.00 # L2 regularization strength
+ self.algo.optim_params.actor.start_epoch = -1 # number of epochs before starting actor training (-1 means start right away)
+ self.algo.optim_params.actor.end_epoch = -1 # number of epochs before ending actor training (-1 means start right away)
+
+ # alpha value - for weighting critic loss vs. BC loss
+ self.algo.alpha = 2.5
+
+ # target network related parameters
+ self.algo.discount = 0.99 # discount factor to use
+ self.algo.n_step = 1 # for using n-step returns in TD-updates
+ self.algo.target_tau = 0.005 # update rate for target networks
+ self.algo.infinite_horizon = False # if True, scale terminal rewards by 1 / (1 - discount) to treat as infinite horizon
+
+ # ================== Critic Network Config ===================
+ self.algo.critic.use_huber = False # Huber Loss instead of L2 for critic
+ self.algo.critic.max_gradient_norm = None # L2 gradient clipping for critic (None to use no clipping)
+ self.algo.critic.value_bounds = None # optional 2-tuple to ensure lower and upper bound on value estimates
+
+ # critic ensemble parameters (TD3 trick)
+ self.algo.critic.ensemble.n = 2 # number of Q networks in the ensemble
+ self.algo.critic.ensemble.weight = 1.0 # weighting for mixing min and max for target Q value
+
+ self.algo.critic.layer_dims = (256, 256) # size of critic MLP
+
+ # ================== Actor Network Config ===================
+
+ # update actor and target networks every n gradients steps for each critic gradient step
+ self.algo.actor.update_freq = 2
+
+ # exploration noise used to form target action for Q-update - clipped Gaussian noise
+ self.algo.actor.noise_std = 0.2 # zero-mean gaussian noise with this std is applied to actions
+ self.algo.actor.noise_clip = 0.5 # noise is clipped in each dimension to (-noise_clip, noise_clip)
+
+ self.algo.actor.layer_dims = (256, 256) # size of actor MLP
+
+ def observation_config(self):
+ """
+ Update from superclass to use flat observations from gym envs.
+ """
+ super(TD3_BCConfig, self).observation_config()
+ self.observation.modalities.obs.low_dim = ["flat"]
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/__init__.py b/phantom/submodules/phantom-robomimic/robomimic/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/env_base.py b/phantom/submodules/phantom-robomimic/robomimic/envs/env_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1006ea1bf4f29357f2b32127fd6cf268697c948
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/env_base.py
@@ -0,0 +1,245 @@
+"""
+This file contains the base class for environment wrappers that are used
+to provide a standardized environment API for training policies and interacting
+with metadata present in datasets.
+"""
+import abc
+
+
+class EnvType:
+ """
+ Holds environment types - one per environment class.
+ These act as identifiers for different environments.
+ """
+ ROBOSUITE_TYPE = 1
+ GYM_TYPE = 2
+ IG_MOMART_TYPE = 3
+ REAL_TYPE = 6
+ GPRS_REAL_TYPE = 7
+ REAL_UR5E_TYPE = 8
+ REAL_KINOVA_TYPE = 9
+
+
+class EnvBase(abc.ABC):
+ """A base class method for environments used by this repo."""
+ @abc.abstractmethod
+ def __init__(
+ self,
+ env_name,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ use_depth_obs=False,
+ postprocess_visual_obs=True,
+ **kwargs,
+ ):
+ """
+ Args:
+ env_name (str): name of environment. Only needs to be provided if making a different
+ environment from the one in @env_meta.
+
+ render (bool): if True, environment supports on-screen rendering
+
+ render_offscreen (bool): if True, environment supports off-screen rendering. This
+ is forced to be True if @env_meta["use_images"] is True.
+
+ use_image_obs (bool): if True, environment is expected to render rgb image observations
+ on every env.step call. Set this to False for efficiency reasons, if image
+ observations are not required.
+
+ use_depth_obs (bool): if True, environment is expected to render depth image observations
+ on every env.step call. Set this to False for efficiency reasons, if depth
+ observations are not required.
+
+ postprocess_visual_obs (bool): if True, postprocess image observations
+ to prepare for learning. This should only be False when extracting observations
+ for saving to a dataset (to save space on RGB images for example).
+ """
+ return
+
+ @abc.abstractmethod
+ def step(self, action):
+ """
+ Step in the environment with an action.
+
+ Args:
+ action (np.array): action to take
+
+ Returns:
+ observation (dict): new observation dictionary
+ reward (float): reward for this step
+ done (bool): whether the task is done
+ info (dict): extra information
+ """
+ return
+
+ @abc.abstractmethod
+ def reset(self):
+ """
+ Reset environment.
+
+ Returns:
+ observation (dict): initial observation dictionary.
+ """
+ return
+
+ @abc.abstractmethod
+ def reset_to(self, state):
+ """
+ Reset to a specific simulator state.
+
+ Args:
+ state (dict): current simulator state
+
+ Returns:
+ observation (dict): observation dictionary after setting the simulator state
+ """
+ return
+
+ @abc.abstractmethod
+ def render(self, mode="human", height=None, width=None, camera_name=None):
+ """Render"""
+ return
+
+ @abc.abstractmethod
+ def get_observation(self):
+ """Get environment observation"""
+ return
+
+ @abc.abstractmethod
+ def get_state(self):
+ """Get environment simulator state, compatible with @reset_to"""
+ return
+
+ @abc.abstractmethod
+ def get_reward(self):
+ """
+ Get current reward.
+ """
+ return
+
+ @abc.abstractmethod
+ def get_goal(self):
+ """
+ Get goal observation. Not all environments support this.
+ """
+ return
+
+ @abc.abstractmethod
+ def set_goal(self, **kwargs):
+ """
+ Set goal observation with external specification. Not all environments support this.
+ """
+ return
+
+ @abc.abstractmethod
+ def is_done(self):
+ """
+ Check if the task is done (not necessarily successful).
+ """
+ return
+
+ @abc.abstractmethod
+ def is_success(self):
+ """
+ Check if the task condition(s) is reached. Should return a dictionary
+ { str: bool } with at least a "task" key for the overall task success,
+ and additional optional keys corresponding to other task criteria.
+ """
+ return
+
+ @property
+ @abc.abstractmethod
+ def action_dimension(self):
+ """
+ Returns dimension of actions (int).
+ """
+ return
+
+ @property
+ @abc.abstractmethod
+ def name(self):
+ """
+ Returns name of environment name (str).
+ """
+ return
+
+ @property
+ @abc.abstractmethod
+ def type(self):
+ """
+ Returns environment type (int) for this kind of environment.
+ This helps identify this env class.
+ """
+ return
+
+ @property
+ def version(self):
+ """
+ Returns version of environment (str).
+ This is not an abstract method, some subclasses do not implement it
+ """
+ return None
+
+ @abc.abstractmethod
+ def serialize(self):
+ """
+ Save all information needed to re-instantiate this environment in a dictionary.
+ This is the same as @env_meta - environment metadata stored in hdf5 datasets,
+ and used in utils/env_utils.py.
+ """
+ return
+
+ @classmethod
+ @abc.abstractmethod
+ def create_for_data_processing(
+ cls,
+ camera_names,
+ camera_height,
+ camera_width,
+ reward_shaping,
+ render=None,
+ render_offscreen=None,
+ use_image_obs=None,
+ use_depth_obs=None,
+ **kwargs,
+ ):
+ """
+ Create environment for processing datasets, which includes extracting
+ observations, labeling dense / sparse rewards, and annotating dones in
+ transitions.
+
+ Args:
+ camera_names ([str]): list of camera names that correspond to image observations
+ camera_height (int): camera height for all cameras
+ camera_width (int): camera width for all cameras
+ reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
+ render (bool or None): optionally override rendering behavior. Defaults to False.
+ render_offscreen (bool or None): optionally override rendering behavior. The default value is True if
+ @camera_names is non-empty, False otherwise.
+ use_image_obs (bool or None): optionally override rendering behavior. The default value is True if
+ @camera_names is non-empty, False otherwise.
+ use_depth_obs (bool): if True, use depth observations
+
+ Returns:
+ env (EnvBase instance)
+ """
+ return
+
+ @property
+ @abc.abstractmethod
+ def rollout_exceptions(self):
+ """
+ Return tuple of exceptions to except when doing rollouts. This is useful to ensure
+ that the entire training run doesn't crash because of a bad policy that causes unstable
+ simulation computations.
+ """
+ return
+
+ @property
+ @abc.abstractmethod
+ def base_env(self):
+ """
+ Grabs base simulation environment.
+ """
+ return
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/env_gym.py b/phantom/submodules/phantom-robomimic/robomimic/envs/env_gym.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b56d1ebb3be670c8e2207fa6afcdf4ee1ec5190
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/env_gym.py
@@ -0,0 +1,267 @@
+"""
+This file contains the gym environment wrapper that is used
+to provide a standardized environment API for training policies and interacting
+with metadata present in datasets.
+"""
+import json
+import numpy as np
+from copy import deepcopy
+
+import gym
+try:
+ import d4rl
+except:
+ print("WARNING: could not load d4rl environments!")
+
+import robomimic.envs.env_base as EB
+import robomimic.utils.obs_utils as ObsUtils
+
+
+class EnvGym(EB.EnvBase):
+ """Wrapper class for gym"""
+ def __init__(
+ self,
+ env_name,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ use_depth_obs=False,
+ postprocess_visual_obs=True,
+ **kwargs,
+ ):
+ """
+ Args:
+ env_name (str): name of environment. Only needs to be provided if making a different
+ environment from the one in @env_meta.
+
+ render (bool): ignored - gym envs always support on-screen rendering
+
+ render_offscreen (bool): ignored - gym envs always support off-screen rendering
+
+ use_image_obs (bool): ignored - gym envs don't typically use images
+
+ postprocess_visual_obs (bool): ignored - gym envs don't typically use images
+ """
+ self._init_kwargs = deepcopy(kwargs)
+ self._env_name = env_name
+ self._current_obs = None
+ self._current_reward = None
+ self._current_done = None
+ self._done = None
+ self.env = gym.make(env_name, **kwargs)
+
+ def step(self, action):
+ """
+ Step in the environment with an action.
+
+ Args:
+ action (np.array): action to take
+
+ Returns:
+ observation (dict): new observation dictionary
+ reward (float): reward for this step
+ done (bool): whether the task is done
+ info (dict): extra information
+ """
+ obs, reward, done, info = self.env.step(action)
+ self._current_obs = obs
+ self._current_reward = reward
+ self._current_done = done
+ return self.get_observation(obs), reward, self.is_done(), info
+
+ def reset(self):
+ """
+ Reset environment.
+
+ Returns:
+ observation (dict): initial observation dictionary.
+ """
+ self._current_obs = self.env.reset()
+ self._current_reward = None
+ self._current_done = None
+ return self.get_observation(self._current_obs)
+
+ def reset_to(self, state):
+ """
+ Reset to a specific simulator state.
+
+ Args:
+ state (dict): current simulator state that contains:
+ - states (np.ndarray): initial state of the mujoco environment
+
+ Returns:
+ observation (dict): observation dictionary after setting the simulator state
+ """
+ if hasattr(self.env.unwrapped.sim, "set_state_from_flattened"):
+ self.env.unwrapped.sim.set_state_from_flattened(state["states"])
+ self.env.unwrapped.sim.forward()
+ return { "flat" : self.env.unwrapped._get_obs() }
+ else:
+ raise NotImplementedError
+
+ def render(self, mode="human", height=None, width=None, camera_name=None, **kwargs):
+ """
+ Render from simulation to either an on-screen window or off-screen to RGB array.
+
+ Args:
+ mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering
+ height (int): height of image to render - only used if mode is "rgb_array"
+ width (int): width of image to render - only used if mode is "rgb_array"
+ """
+ if mode =="human":
+ return self.env.render(mode=mode, **kwargs)
+ if mode == "rgb_array":
+ return self.env.render(mode="rgb_array", height=height, width=width)
+ else:
+ raise NotImplementedError("mode={} is not implemented".format(mode))
+
+ def get_observation(self, obs=None):
+ """
+ Get current environment observation dictionary.
+
+ Args:
+ ob (np.array): current flat observation vector to wrap and provide as a dictionary.
+ If not provided, uses self._current_obs.
+ """
+ if obs is None:
+ assert self._current_obs is not None
+ obs = self._current_obs
+ return { "flat" : np.copy(obs) }
+
+ def get_state(self):
+ """
+ Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
+ """
+ # NOTE: assumes MuJoCo gym task!
+ xml = self.env.sim.model.get_xml() # model xml file
+ state = np.array(self.env.sim.get_state().flatten()) # simulator state
+ return dict(model=xml, states=state)
+
+ def get_reward(self):
+ """
+ Get current reward.
+ """
+ assert self._current_reward is not None
+ return self._current_reward
+
+ def get_goal(self):
+ """
+ Get goal observation. Not all environments support this.
+ """
+ raise NotImplementedError
+
+ def set_goal(self, **kwargs):
+ """
+ Set goal observation with external specification. Not all environments support this.
+ """
+ raise NotImplementedError
+
+ def is_done(self):
+ """
+ Check if the task is done (not necessarily successful).
+ """
+ assert self._current_done is not None
+ return self._current_done
+
+ def is_success(self):
+ """
+ Check if the task condition(s) is reached. Should return a dictionary
+ { str: bool } with at least a "task" key for the overall task success,
+ and additional optional keys corresponding to other task criteria.
+ """
+ if hasattr(self.env.unwrapped, "_check_success"):
+ return self.env.unwrapped._check_success()
+
+ # gym envs generally don't check task success - we only compare returns
+ return { "task" : False }
+
+ @property
+ def action_dimension(self):
+ """
+ Returns dimension of actions (int).
+ """
+ return self.env.action_space.shape[0]
+
+ @property
+ def name(self):
+ """
+ Returns name of environment name (str).
+ """
+ return self._env_name
+
+ @property
+ def type(self):
+ """
+ Returns environment type (int) for this kind of environment.
+ This helps identify this env class.
+ """
+ return EB.EnvType.GYM_TYPE
+
+ def serialize(self):
+ """
+ Save all information needed to re-instantiate this environment in a dictionary.
+ This is the same as @env_meta - environment metadata stored in hdf5 datasets,
+ and used in utils/env_utils.py.
+ """
+ return dict(env_name=self.name, type=self.type, env_kwargs=deepcopy(self._init_kwargs))
+
+ @classmethod
+ def create_for_data_processing(
+ cls,
+ env_name,
+ camera_names,
+ camera_height,
+ camera_width,
+ reward_shaping,
+ render=None,
+ render_offscreen=None,
+ use_image_obs=None,
+ use_depth_obs=None,
+ **kwargs,
+ ):
+ """
+ Create environment for processing datasets, which includes extracting
+ observations, labeling dense / sparse rewards, and annotating dones in
+ transitions. For gym environments, input arguments (other than @env_name)
+ are ignored, since environments are mostly pre-configured.
+
+ Args:
+ env_name (str): name of gym environment to create
+
+ Returns:
+ env (EnvGym instance)
+ """
+
+ # make sure to initialize obs utils so it knows which modalities are image modalities.
+ # For currently supported gym tasks, there are no image observations.
+ obs_modality_specs = {
+ "obs": {
+ "low_dim": ["flat"],
+ "rgb": [],
+ }
+ }
+ ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs)
+
+ return cls(env_name=env_name, **kwargs)
+
+ @property
+ def rollout_exceptions(self):
+ """
+ Return tuple of exceptions to except when doing rollouts. This is useful to ensure
+ that the entire training run doesn't crash because of a bad policy that causes unstable
+ simulation computations.
+ """
+ return ()
+
+ @property
+ def base_env(self):
+ """
+ Grabs base simulation environment.
+ """
+ return self.env
+
+ def __repr__(self):
+ """
+ Pretty-print env description.
+ """
+ return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/env_ig_momart.py b/phantom/submodules/phantom-robomimic/robomimic/envs/env_ig_momart.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd0a0db9df116ea57fe75cad5c526a07bb08e81d
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/env_ig_momart.py
@@ -0,0 +1,414 @@
+"""
+Wrapper environment class to enable using iGibson-based environments used in the MOMART paper
+"""
+
+from copy import deepcopy
+import numpy as np
+import json
+
+import pybullet as p
+import gibson2
+from gibson2.envs.semantic_organize_and_fetch import SemanticOrganizeAndFetch
+from gibson2.utils.custom_utils import ObjectConfig
+import gibson2.external.pybullet_tools.utils as PBU
+import tempfile
+import os
+import yaml
+import cv2
+
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.envs.env_base as EB
+
+
+# TODO: Once iG 2.0 is more stable, automate available environments, similar to robosuite
+ENV_MAPPING = {
+ "SemanticOrganizeAndFetch": SemanticOrganizeAndFetch,
+}
+
+
+class EnvGibsonMOMART(EB.EnvBase):
+ """
+ Wrapper class for gibson environments (https://github.com/StanfordVL/iGibson) specifically compatible with
+ MoMaRT datasets
+ """
+ def __init__(
+ self,
+ env_name,
+ ig_config,
+ postprocess_visual_obs=True,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ use_depth_obs=False,
+ image_height=None,
+ image_width=None,
+ physics_timestep=1./240.,
+ action_timestep=1./20.,
+ **kwargs,
+ ):
+ """
+ Args:
+ ig_config (dict): YAML configuration to use for iGibson, as a dict
+
+ postprocess_visual_obs (bool): if True, postprocess image observations
+ to prepare for learning
+
+ render (bool): if True, environment supports on-screen rendering
+
+ render_offscreen (bool): if True, environment supports off-screen rendering. This
+ is forced to be True if @use_image_obs is True.
+
+ use_image_obs (bool): if True, environment is expected to render rgb image observations
+ on every env.step call. Set this to False for efficiency reasons, if image
+ observations are not required.
+
+ use_depth_obs (bool): if True, environment is expected to render depth image observations
+ on every env.step call. Set this to False for efficiency reasons, if depth
+ observations are not required.
+
+ render_mode (str): How to run simulation rendering. Options are {"pbgui", "iggui", or "headless"}
+
+ image_height (int): If specified, overrides internal iG image height when rendering
+
+ image_width (int): If specified, overrides internal iG image width when rendering
+
+ physics_timestep (float): Pybullet physics timestep to use
+
+ action_timestep (float): Action timestep to use for robot in simulation
+
+ kwargs (unrolled dict): Any args to substitute in the ig_configuration
+ """
+ self._env_name = env_name
+ self.ig_config = deepcopy(ig_config)
+ self.postprocess_visual_obs = postprocess_visual_obs
+ self._init_kwargs = kwargs
+
+ # Determine rendering mode
+ self.render_mode = "iggui" if render else "headless"
+ self.render_onscreen = render
+
+ # Make sure rgb is part of obs in ig config
+ self.ig_config["output"] = list(set(self.ig_config["output"] + ["rgb"]))
+
+ # Warn user that iG always uses a renderer
+ if (not render) and (not render_offscreen):
+ print("WARNING: iGibson always uses a renderer -- using headless by default.")
+
+ # Update ig config
+ for k, v in kwargs.items():
+ assert k in self.ig_config, f"Got unknown ig configuration key {k}!"
+ self.ig_config[k] = v
+
+ # Set rendering values
+ self.obs_img_height = image_height if image_height is not None else self.ig_config.get("obs_image_height", 120)
+ self.obs_img_width = image_width if image_width is not None else self.ig_config.get("obs_image_width", 120)
+
+ # Get class to create
+ envClass = ENV_MAPPING.get(self._env_name, None)
+
+ # Make sure we have a valid environment class
+ assert envClass is not None, "No valid environment for the requested task was found!"
+
+ # Set device idx for rendering
+ # ensure that we select the correct GPU device for rendering by testing for EGL rendering
+ # NOTE: this package should be installed from this link (https://github.com/StanfordVL/egl_probe)
+ import egl_probe
+ device_idx = 0
+ valid_gpu_devices = egl_probe.get_available_devices()
+ if len(valid_gpu_devices) > 0:
+ device_idx = valid_gpu_devices[0]
+
+ # Create environment
+ self.env = envClass(
+ config_file=deepcopy(self.ig_config),
+ mode=self.render_mode,
+ physics_timestep=physics_timestep,
+ action_timestep=action_timestep,
+ device_idx=device_idx,
+ )
+
+ # If we have a viewer, make sure to remove all bodies belonging to the visual markers
+ self.exclude_body_ids = [] # Bodies to exclude when saving state
+ if self.env.simulator.viewer is not None:
+ self.exclude_body_ids.append(self.env.simulator.viewer.constraint_marker.body_id)
+ self.exclude_body_ids.append(self.env.simulator.viewer.constraint_marker2.body_id)
+
+ def step(self, action):
+ """
+ Step in the environment with an action
+
+ Args:
+ action: action to take
+
+ Returns:
+ observation: new observation
+ reward: step reward
+ done: whether the task is done
+ info: extra information
+ """
+ obs, r, done, info = self.env.step(action)
+ obs = self.get_observation(obs)
+ return obs, r, self.is_done(), info
+
+ def reset(self):
+ """Reset environment"""
+ di = self.env.reset()
+ return self.get_observation(di)
+
+ def reset_to(self, state):
+ """
+ Reset to a specific state
+ Args:
+ state (dict): contains:
+ - states (np.ndarray): initial state of the mujoco environment
+ - goal (dict): goal components to reset
+ Returns:
+ new observation
+ """
+ if "states" in state:
+ self.env.reset_to(state["states"], exclude=self.exclude_body_ids)
+
+ if "goal" in state:
+ self.set_goal(**state["goal"])
+
+ # Return obs
+ return self.get_observation()
+
+ def render(self, mode="human", camera_name="rgb", height=None, width=None):
+ """
+ Render
+
+ Args:
+ mode (str): Mode(s) to render. Options are either 'human' (rendering onscreen) or 'rgb' (rendering to
+ frames offscreen)
+ camera_name (str): Name of the camera to use -- valid options are "rgb" or "rgb_wrist"
+ height (int): If specified with width, resizes the rendered image to this height
+ width (int): If specified with height, resizes the rendered image to this width
+
+ Returns:
+ array or None: If rendering to frame, returns the rendered frame. Otherwise, returns None
+ """
+ # Only robotview camera is currently supported
+ assert camera_name in {"rgb", "rgb_wrist"}, \
+ f"Only rgb, rgb_wrist cameras currently supported, got {camera_name}."
+
+ if mode == "human":
+ assert self.render_onscreen, "Rendering has not been enabled for onscreen!"
+ self.env.simulator.sync()
+ else:
+ assert self.env.simulator.renderer is not None, "No renderer enabled for this env!"
+
+ frame = self.env.sensors["vision"].get_obs(self.env)[camera_name]
+
+ # Reshape all frames
+ if height is not None and width is not None:
+ frame = cv2.resize(frame, dsize=(height, width), interpolation=cv2.INTER_CUBIC)
+ return frame
+
+ def resize_obs_frame(self, frame):
+ """
+ Resizes frame to be internal height and width values
+ """
+ return cv2.resize(frame, dsize=(self.obs_img_width, self.obs_img_height), interpolation=cv2.INTER_CUBIC)
+
+ def get_observation(self, di=None):
+ """Get environment observation"""
+ if di is None:
+ di = self.env.get_state()
+ ret = {}
+ for k in di:
+ # RGB Images
+ if "rgb" in k:
+ ret[k] = di[k]
+ # ret[k] = np.transpose(di[k], (2, 0, 1))
+ if self.postprocess_visual_obs:
+ ret[k] = ObsUtils.process_obs(obs=self.resize_obs_frame(ret[k]), obs_key=k)
+
+ # Depth images
+ elif "depth" in k:
+ # ret[k] = np.transpose(di[k], (2, 0, 1))
+ # Values can be corrupted (negative or > 1.0, so we clip values)
+ ret[k] = np.clip(di[k], 0.0, 1.0)
+ if self.postprocess_visual_obs:
+ ret[k] = ObsUtils.process_obs(obs=self.resize_obs_frame(ret[k])[..., None], obs_key=k)
+
+ # Segmentation Images
+ elif "seg" in k:
+ ret[k] = di[k][..., None]
+ if self.postprocess_visual_obs:
+ ret[k] = ObsUtils.process_obs(obs=self.resize_obs_frame(ret[k]), obs_key=k)
+
+ # Scans
+ elif "scan" in k:
+ ret[k] = np.transpose(np.array(di[k]), axes=(1, 0))
+
+ # Compose proprio obs
+ proprio_obs = di["proprio"]
+
+ # Compute intermediate values
+ lin_vel = np.linalg.norm(proprio_obs["base_lin_vel"][:2])
+ ang_vel = proprio_obs["base_ang_vel"][2]
+
+ ret["proprio"] = np.concatenate([
+ proprio_obs["head_joint_pos"],
+ proprio_obs["grasped"],
+ proprio_obs["eef_pos"],
+ proprio_obs["eef_quat"],
+ ])
+
+ # Proprio info that's only relevant for navigation
+ ret["proprio_nav"] = np.concatenate([
+ [lin_vel],
+ [ang_vel],
+ ])
+
+ # Compose task obs
+ ret["object"] = np.concatenate([
+ np.array(di["task_obs"]["object-state"]),
+ ])
+
+ # Add ground truth navigational state
+ ret["gt_nav"] = np.concatenate([
+ proprio_obs["base_pos"][:2],
+ [np.sin(proprio_obs["base_rpy"][2])],
+ [np.cos(proprio_obs["base_rpy"][2])],
+ ])
+
+ return ret
+
+ def sync_task(self):
+ """
+ Method to synchronize iG task, since we're not actually resetting the env but instead setting states directly.
+ Should only be called after resetting the initial state of an episode
+ """
+ self.env.task.update_target_object_init_pos()
+ self.env.task.update_location_info()
+
+ def set_task_conditions(self, task_conditions):
+ """
+ Method to override task conditions (e.g.: target object), useful in cases such as playing back
+ from demonstrations
+
+ Args:
+ task_conditions (dict): Keyword-mapped arguments to pass to task instance to set internally
+ """
+ self.env.set_task_conditions(task_conditions)
+
+ def get_state(self):
+ """Get iG flattened state"""
+ return {"states": PBU.WorldSaver(exclude_body_ids=self.exclude_body_ids).serialize()}
+
+ def get_reward(self):
+ return self.env.task.get_reward(self.env)[0]
+ # return float(self.is_success()["task"])
+
+ def get_goal(self):
+ """Get goal specification"""
+ # No support yet in iG
+ raise NotImplementedError
+
+ def set_goal(self, **kwargs):
+ """Set env target with external specification"""
+ # No support yet in iG
+ raise NotImplementedError
+
+ def is_done(self):
+ """Check if the agent is done (not necessarily successful)."""
+ return False
+
+ def is_success(self):
+ """
+ Check if the task condition(s) is reached. Should return a dictionary
+ { str: bool } with at least a "task" key for the overall task success,
+ and additional optional keys corresponding to other task criteria.
+ """
+ succ = self.env.check_success()
+ if isinstance(succ, dict):
+ assert "task" in succ
+ return succ
+ return { "task" : succ }
+
+ @classmethod
+ def create_for_data_processing(
+ cls,
+ env_name,
+ camera_names,
+ camera_height,
+ camera_width,
+ reward_shaping,
+ render=None,
+ render_offscreen=None,
+ use_image_obs=None,
+ use_depth_obs=None,
+ **kwargs,
+ ):
+ """
+ Create environment for processing datasets, which includes extracting
+ observations, labeling dense / sparse rewards, and annotating dones in
+ transitions.
+
+ Args:
+ env_name (str): name of environment
+ camera_names (list of str): list of camera names that correspond to image observations
+ camera_height (int): camera height for all cameras
+ camera_width (int): camera width for all cameras
+ reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
+ render (bool or None): optionally override rendering behavior
+ render_offscreen (bool or None): optionally override rendering behavior
+ use_image_obs (bool or None): optionally override rendering behavior
+ """
+ has_camera = (len(camera_names) > 0)
+
+ # note that @postprocess_visual_obs is False since this env's images will be written to a dataset
+ return cls(
+ env_name=env_name,
+ render=(False if render is None else render),
+ render_offscreen=(has_camera if render_offscreen is None else render_offscreen),
+ use_image_obs=(has_camera if use_image_obs is None else use_image_obs),
+ postprocess_visual_obs=False,
+ image_height=camera_height,
+ image_width=camera_width,
+ **kwargs,
+ )
+
+ @property
+ def action_dimension(self):
+ """Action dimension"""
+ return self.env.robots[0].action_dim
+
+ @property
+ def name(self):
+ """Environment name"""
+ return self._env_name
+
+ @property
+ def type(self):
+ """Environment type"""
+ return EB.EnvType.IG_MOMART_TYPE
+
+ def serialize(self):
+ """Serialize to dictionary"""
+ return dict(env_name=self.name, type=self.type,
+ ig_config=self.ig_config,
+ env_kwargs=deepcopy(self._init_kwargs))
+
+ @classmethod
+ def deserialize(cls, info, postprocess_visual_obs=True):
+ """Create environment with external info"""
+ return cls(env_name=info["env_name"], ig_config=info["ig_config"], postprocess_visual_obs=postprocess_visual_obs, **info["env_kwargs"])
+
+ @property
+ def rollout_exceptions(self):
+ """Return tuple of exceptions to except when doing rollouts"""
+ return (RuntimeError)
+
+ @property
+ def base_env(self):
+ """
+ Grabs base simulation environment.
+ """
+ return self.env
+
+ def __repr__(self):
+ return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4) + \
+ "\niGibson Config: \n" + json.dumps(self.ig_config, sort_keys=True, indent=4)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/env_real_panda.py b/phantom/submodules/phantom-robomimic/robomimic/envs/env_real_panda.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59a979724fe946c0e56c4e9f0e8eab6b8d03214
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/env_real_panda.py
@@ -0,0 +1,448 @@
+"""
+This file contains the base class for environment wrappers that are used
+to provide a standardized environment API for training policies and interacting
+with metadata present in datasets.
+"""
+import time
+import json
+import sys
+import numpy as np
+from copy import deepcopy
+
+import cv2
+
+import RobotTeleop
+import RobotTeleop.utils as U
+from RobotTeleop.utils import Rate, RateMeasure, Timers
+
+import robomimic.envs.env_base as EB
+import robomimic.utils.obs_utils as ObsUtils
+
+class EnvRealPanda(EB.EnvBase):
+ """Wrapper class for real panda environment"""
+ def __init__(
+ self,
+ env_name,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=True,
+ use_depth_obs=False,
+ postprocess_visual_obs=True,
+ control_freq=20.,
+ action_scale=None,
+ camera_names_to_sizes=None,
+ init_ros_node=True,
+ publish_target_pose=False,
+ fake_controller=False,
+ use_moveit=True,
+ ):
+ """
+ Args:
+ env_name (str): name of environment.
+
+ render (bool): ignored - on-screen rendering is not supported
+
+ render_offscreen (bool): ignored - image observations are supplied by default
+
+ use_image_obs (bool): ignored - image observations are used by default.
+
+ postprocess_visual_obs (bool): if True, postprocess image observations
+ to prepare for learning. This should only be False when extracting observations
+ for saving to a dataset (to save space on RGB images for example).
+
+ control_freq (int): real-world control frequency to try and enforce through rate-limiting
+
+ action_scale (list): list of 7 numbers for what the -1 and 1 action in each dimension corresponds to
+ for the physical robot action space
+
+ camera_names_to_sizes (dict): dictionary that maps camera names to tuple of image height and width
+ to return
+ """
+ self._env_name = env_name
+ self.postprocess_visual_obs = postprocess_visual_obs
+ self.control_freq = control_freq
+
+ # to enforce control rate
+ self.rate = Rate(control_freq)
+ self.rate_measure = RateMeasure(name="robot", freq_threshold=round(0.95 * control_freq))
+ self.timers = Timers(history=100, disable_on_creation=False)
+
+ assert (action_scale is not None), "must provide action scaling bounds"
+ assert len(action_scale) == 7, "must provide scaling for all dimensions"
+ self.action_scale = np.array(action_scale).reshape(-1)
+
+ camera_names_to_sizes = deepcopy(camera_names_to_sizes)
+ if camera_names_to_sizes is None:
+ self.camera_names_to_sizes = {}
+ else:
+ self.camera_names_to_sizes = camera_names_to_sizes
+
+ # save kwargs for serialization
+ kwargs = dict(
+ camera_names_to_sizes=camera_names_to_sizes,
+ action_scale=action_scale,
+ init_ros_node=init_ros_node,
+ publish_target_pose=publish_target_pose,
+ fake_controller=fake_controller,
+ use_moveit=use_moveit,
+ control_freq=control_freq
+ )
+ self._init_kwargs = deepcopy(kwargs)
+
+ # connect to robot
+ # if (sys.version_info > (3, 0)):
+ # from RobotTeleop.robots.panda_redis_interface import PandaRedisInterface
+ # self.robot_interface = PandaRedisInterface(
+ # init_ros_node=init_ros_node,
+ # publish_target_pose=publish_target_pose,
+ # fake_controller=fake_controller,
+ # use_moveit=use_moveit,
+ # camera_names_to_sizes=camera_names_to_sizes,
+ # debug_times=True,
+ # )
+ # else:
+ from RobotTeleop.robots.panda_ros_interface import PandaRosInterface
+ self.robot_interface = PandaRosInterface(
+ init_ros_node=init_ros_node,
+ publish_target_pose=publish_target_pose,
+ fake_controller=fake_controller,
+ use_moveit=use_moveit,
+ camera_names_to_sizes=camera_names_to_sizes,
+ #use_redis=True,
+ )
+
+ # IMPORTANT: initialize JIT functions that may need to compile
+ self._compile_jit_functions()
+
+ # last grasp action - initialize to false, since gripper should start open
+ self.did_grasp = False
+
+ def _compile_jit_functions(self):
+ """
+ Helper function to incur the cost of compiling jit functions used by this class upfront.
+
+ NOTE: this function looks strange because we apparently need to make it look like the env.step function
+ for it to compile properly, otherwise we will have a heavy delay on the first env.step call...
+
+ TODO: figure out why this needs to look like the step function code below...
+ """
+
+ # current robot state to use as reference
+ ee_pos, ee_quat = self.robot_interface.ee_pose
+ ee_mat = U.quat2mat(ee_quat)
+ ee_quat_hat = U.mat2quat(ee_mat)
+
+ # convert delta axis-angle to delta rotation matrix, and from there, to absolute target rotation
+ drot = np.array([0., 0., 0.05])
+ angle = np.linalg.norm(drot)
+ if U.isclose(angle, 0.):
+ drot_quat = np.array([0., 0., 0., 1.])
+ else:
+ axis = drot / angle
+ drot_quat = U.axisangle2quat(axis, angle)
+
+ # get target rotation
+ drot_mat = U.quat2mat(drot_quat)
+ target_rot_mat = (drot_mat.T).dot(ee_mat)
+ target_rot_quat = U.mat2quat(target_rot_mat)
+
+ def step(self, action, need_obs=True):
+ """
+ Step in the environment with an action.
+
+ Args:
+ action (np.array): action to take, should be in [-1, 1]
+ need_obs (bool): if False, don't return the observation, because this
+ can involve copying image data around. This allows for more
+ flexibility on when observations are retrieved.
+
+ Returns:
+ observation (dict): new observation dictionary
+ reward (float): reward for this step
+ done (bool): whether the task is done
+ info (dict): extra information
+ """
+ assert len(action.shape) == 1 and action.shape[0] == 7, "action has incorrect dimensions"
+ assert np.min(action) >= -1. and np.max(action) <= 1., "incorrect action bounds"
+
+ # rate-limiting
+ self.rate.sleep()
+ self.rate_measure.measure()
+
+ self.timers.tic("real_panda_step")
+
+ # unscale action
+ action = self.action_scale * action
+
+ # extract action components
+ dpos = action[:3]
+ drot = action[3:6]
+ gripper_command = action[6:7]
+
+ # current robot state to use as reference
+ ee_pos, ee_quat = self.robot_interface.ee_pose
+ ee_mat = U.quat2mat(ee_quat)
+
+ # absolute target position
+ target_pos = ee_pos + dpos
+
+ # convert delta axis-angle to delta rotation matrix, and from there, to absolute target rotation
+ angle = np.linalg.norm(drot)
+ if U.isclose(angle, 0.):
+ drot_quat = np.array([0., 0., 0., 1.])
+ else:
+ axis = drot / angle
+ drot_quat = U.axisangle2quat(axis, angle)
+ drot_mat = U.quat2mat(drot_quat)
+ target_rot_mat = (drot_mat.T).dot(ee_mat)
+ target_rot_quat = U.mat2quat(target_rot_mat)
+
+ # play end effector action
+ self.robot_interface.move_to_ee_pose(pos=target_pos, ori=target_rot_quat)
+
+ # convert continuous control signal in [-1, 1] to boolean
+ should_close = (float(gripper_command) < 0.)
+
+ # only send command if trying to change gripper state.
+ # this is due to hardware limitations - robot grippers suck.
+ if should_close != self.did_grasp:
+ if should_close:
+ self.robot_interface.gripper_close()
+ else:
+ self.robot_interface.gripper_open()
+
+ # remember last grasp command
+ self.did_grasp = should_close
+
+ # get observation
+ obs = None
+ if need_obs:
+ obs = self.get_observation()
+ r = self.get_reward()
+ done = self.is_done()
+
+ self.timers.toc("real_panda_step")
+
+ return obs, r, done, {}
+
+ def reset(self):
+ """
+ Reset environment.
+
+ Returns:
+ observation (dict): initial observation dictionary.
+ """
+ self.robot_interface.gripper_open()
+ self.robot_interface.reset_teleop()
+ self.rate_measure = RateMeasure(name="robot", freq_threshold=round(0.95 * self.control_freq))
+
+ return self.get_observation()
+
+ def reset_to(self, state):
+ """
+ Reset to a specific state. On real robot, we visualize the start image,
+ and a human should manually reset the scene.
+
+ Reset to a specific simulator state.
+
+ Args:
+ state (dict): initial state that contains:
+ - image (np.ndarray): initial workspace image
+
+ Returns:
+ None
+ """
+ assert "front_image" in state
+ ref_img = cv2.cvtColor(state["front_image"], cv2.COLOR_RGB2BGR)
+
+ print("\n" + "*" * 50)
+ print("Reset environment to image shown in left pane")
+ print("Press 'c' when ready to continue.")
+ print("*" * 50 + "\n")
+ while(True):
+ # read current image
+ cur_img = self.robot_interface.get_camera_frame(camera_name="front_image")
+ cur_img = cv2.cvtColor(cur_img, cv2.COLOR_RGB2BGR)
+
+ # concatenate frames to display
+ img = np.concatenate([ref_img, cur_img], axis=1)
+
+ # display frame
+ cv2.imshow('initial state alignment window', img)
+ if cv2.waitKey(1) & 0xFF == ord('c'):
+ cv2.destroyAllWindows()
+ break
+
+ def render(self, mode="human", height=None, width=None, camera_name=None, **kwargs):
+ """
+ Render from simulation to either an on-screen window or off-screen to RGB array.
+
+ Args:
+ mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering
+ height (int): height of image to render - only used if mode is "rgb_array"
+ width (int): width of image to render - only used if mode is "rgb_array"
+ """
+ if mode =="human":
+ raise Exception("on-screen rendering not supported currently")
+ if mode == "rgb_array":
+ # assert (height is None) and (width is None), "cannot resize images"
+ assert camera_name in self.camera_names_to_sizes, "invalid camera name"
+ return self.robot_interface.get_camera_frame(camera_name=camera_name)
+ else:
+ raise NotImplementedError("mode={} is not implemented".format(mode))
+
+ def get_observation(self, obs=None):
+ """
+ Get current environment observation dictionary.
+
+ Args:
+ ob (np.array): current observation dictionary.
+ """
+ self.timers.tic("get_observation")
+ observation = {}
+ observation["ee_pose"] = np.concatenate(self.robot_interface.ee_pose)
+ observation["joint_positions"] = self.robot_interface.joint_position
+ observation["joint_velocities"] = self.robot_interface.joint_velocity
+ observation["gripper_position"] = self.robot_interface.gripper_position
+ observation["gripper_velocity"] = self.robot_interface.gripper_velocity
+ for cam_name in self.camera_names_to_sizes:
+ im = self.robot_interface.get_camera_frame(camera_name=cam_name)
+ if self.postprocess_visual_obs:
+ im = ObsUtils.process_image(im)
+ observation[cam_name] = im
+ self.timers.toc("get_observation")
+ return observation
+
+ def get_state(self):
+ """
+ Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
+ """
+ return dict(states=np.zeros(1))
+ # raise Exception("Real robot has no simulation state.")
+
+ def get_reward(self):
+ """
+ Get current reward.
+ """
+ return 0.
+
+ def get_goal(self):
+ """
+ Get goal observation. Not all environments support this.
+ """
+ raise NotImplementedError
+
+ def set_goal(self, **kwargs):
+ """
+ Set goal observation with external specification. Not all environments support this.
+ """
+ raise NotImplementedError
+
+ def is_done(self):
+ """
+ Check if the task is done (not necessarily successful).
+ """
+ return False
+
+ def is_success(self):
+ """
+ Check if the task condition(s) is reached. Should return a dictionary
+ { str: bool } with at least a "task" key for the overall task success,
+ and additional optional keys corresponding to other task criteria.
+ """
+
+ # real robot environments don't usually have a success check - this must be done manually
+ return { "task" : False }
+
+ @property
+ def action_dimension(self):
+ """
+ Returns dimension of actions (int).
+ """
+ return 7
+
+ @property
+ def name(self):
+ """
+ Returns name of environment name (str).
+ """
+ # return self._env_name
+
+ # for real robot. ensure class name is stored in env meta (as env name) for use with any external
+ # class registries
+ return self.__class__.__name__
+
+ @property
+ def type(self):
+ """
+ Returns environment type (int) for this kind of environment.
+ This helps identify this env class.
+ """
+ return EB.EnvType.REAL_TYPE
+
+ def serialize(self):
+ """
+ Save all information needed to re-instantiate this environment in a dictionary.
+ This is the same as @env_meta - environment metadata stored in hdf5 datasets,
+ and used in utils/env_utils.py.
+ """
+ return dict(env_name=self.name, type=self.type, env_kwargs=deepcopy(self._init_kwargs))
+
+ @classmethod
+ def create_for_data_processing(cls, env_name, camera_names, camera_height, camera_width, reward_shaping, **kwargs):
+ """
+ Create environment for processing datasets, which includes extracting
+ observations, labeling dense / sparse rewards, and annotating dones in
+ transitions. For gym environments, input arguments (other than @env_name)
+ are ignored, since environments are mostly pre-configured.
+
+ Args:
+ env_name (str): name of gym environment to create
+
+ Returns:
+ env (EnvRealPanda instance)
+ """
+
+ # initialize obs utils so it knows which modalities are image modalities
+ assert "camera_names_to_sizes" in kwargs
+ image_modalities = list(kwargs["camera_names_to_sizes"].keys())
+ obs_modality_specs = {
+ "obs": {
+ "low_dim": [], # technically unused, so we don't have to specify all of them
+ "image": image_modalities,
+ }
+ }
+ ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs)
+
+ # note that @postprocess_visual_obs is False since this env's images will be written to a dataset
+ return cls(
+ env_name=env_name,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=True,
+ postprocess_visual_obs=False,
+ **kwargs,
+ )
+
+ @property
+ def rollout_exceptions(self):
+ """
+ Return tuple of exceptions to except when doing rollouts. This is useful to ensure
+ that the entire training run doesn't crash because of a bad policy that causes unstable
+ simulation computations.
+ """
+ return ()
+
+ @property
+ def base_env(self):
+ """
+ Grabs base simulation environment.
+ """
+ # we don't wrap any env
+ return self
+
+ def __repr__(self):
+ """
+ Pretty-print env description.
+ """
+ return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/env_real_panda_gprs.py b/phantom/submodules/phantom-robomimic/robomimic/envs/env_real_panda_gprs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f931074dd0d7f9bb45b706373cd0559d81cd1e53
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/env_real_panda_gprs.py
@@ -0,0 +1,732 @@
+"""
+Real robot env wrapper for Yifeng's GPRS control stack.
+"""
+import os
+import time
+import json
+import sys
+import numpy as np
+from copy import deepcopy
+from easydict import EasyDict as edict
+
+import cv2
+from PIL import Image
+
+import RobotTeleop
+import RobotTeleop.utils as U
+from RobotTeleop.utils import Rate, RateMeasure, Timers
+
+try:
+ # GPRS imports
+ from gprs.franka_interface import FrankaInterface
+ from gprs.camera_redis_interface import CameraRedisSubInterface
+ from gprs.utils import YamlConfig
+ from gprs import config_root
+
+ from rpl_vision_utils.utils import img_utils as ImgUtils
+except ImportError:
+ print("WARNING: no GPRS...")
+
+import robomimic.envs.env_base as EB
+import robomimic.utils.obs_utils as ObsUtils
+from robomimic.utils.log_utils import log_warning
+
+try:
+ import robosuite.utils.transform_utils as T
+except ImportError:
+ print("WARNING: could not import robosuite transform utils (needed for using absolute actions with GPRS")
+
+
+def center_crop(im, t_h, t_w):
+ assert(im.shape[-3] >= t_h and im.shape[-2] >= t_w)
+ assert(im.shape[-1] in [1, 3])
+ crop_h = int((im.shape[-3] - t_h) / 2)
+ crop_w = int((im.shape[-2] - t_w) / 2)
+ return im[..., crop_h:crop_h + t_h, crop_w:crop_w + t_w, :]
+
+
+def get_depth_scale(camera_name):
+ """
+ Returns scaling factor that converts from uint16 depth to real-valued depth (in meters).
+ """
+
+ # TODO: fix duplication
+ if camera_name == "front":
+ return 0.0010000000474974513
+ if camera_name == "wrist":
+ return 0.0010000000474974513
+ raise Exception("should not reach here")
+ # from RobotTeleop.scripts.debug_april_tag import get_depth_scale_unified
+ # return get_depth_scale_unified(camera_name=camera_name)
+
+
+class EnvRealPandaGPRS(EB.EnvBase):
+ """Wrapper class for real panda environment"""
+ def __init__(
+ self,
+ env_name,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=True,
+ postprocess_visual_obs=True,
+ control_freq=20.,
+ camera_names_to_sizes=None,
+ center_crop_images=True,
+ general_cfg_file=None,
+ controller_type=None,
+ controller_cfg_file=None,
+ controller_cfg_dict=None,
+ use_depth_obs=False,
+ absolute_actions=False, # use absolute pos and rot (axis-angle) in 7-dim action vector
+ # additional GPRS-specific args
+ state_freq=100.,
+ control_timeout=1.0,
+ has_gripper=True,
+ use_visualizer=False,
+ debug=False,
+ ):
+ """
+ Args:
+ env_name (str): name of environment.
+
+ render (bool): ignored - on-screen rendering is not supported
+
+ render_offscreen (bool): ignored - image observations are supplied by default
+
+ use_image_obs (bool): ignored - image observations are used by default.
+
+ postprocess_visual_obs (bool): if True, postprocess image observations
+ to prepare for learning. This should only be False when extracting observations
+ for saving to a dataset (to save space on RGB images for example).
+
+ control_freq (int): real-world control frequency to try and enforce through rate-limiting
+
+ camera_names_to_sizes (dict): dictionary that maps camera names to tuple of image height and width
+ to return
+ """
+ self._env_name = env_name
+ self.postprocess_visual_obs = postprocess_visual_obs
+ self.control_freq = control_freq
+ self.absolute_actions = absolute_actions
+ self.general_cfg_file = general_cfg_file
+ self.controller_type = controller_type
+ self.controller_cfg_file = controller_cfg_file
+ self.controller_cfg_dict = deepcopy(controller_cfg_dict) if controller_cfg_dict is not None else None
+ if self.controller_cfg_dict is not None:
+ # control code expects easydict
+ self.controller_cfg = edict(self.controller_cfg_dict)
+ else:
+ assert controller_cfg_file is not None
+ self.controller_cfg = YamlConfig(os.path.join(config_root, controller_cfg_file)).as_easydict()
+ self.use_depth_obs = use_depth_obs
+
+ # to enforce control rate
+ self.rate = Rate(control_freq)
+ self.rate_measure = RateMeasure(name="robot", freq_threshold=round(0.95 * control_freq))
+ self.timers = Timers(history=100, disable_on_creation=False)
+
+ camera_names_to_sizes = deepcopy(camera_names_to_sizes)
+ if camera_names_to_sizes is None:
+ self.camera_names_to_sizes = {}
+ else:
+ self.camera_names_to_sizes = camera_names_to_sizes
+ self.center_crop_images = center_crop_images
+
+ self._exclude_depth_from_obs = (not self.use_depth_obs)
+ if self.use_depth_obs and self.postprocess_visual_obs:
+ for cam_name in self.camera_names_to_sizes:
+ depth_mod = "{}_depth".format(cam_name)
+ if not ((depth_mod in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=depth_mod, obs_modality="depth")):
+ log_warning("depth observation {} will not be postprocessed since robomimic is not aware of it".format(depth_mod))
+ # # HACK: assume this means we don't actually need depth, but we might the camera interface to support it for TAMP / perception
+ # self.use_depth_obs = False
+ self._exclude_depth_from_obs = True
+
+ # save kwargs for serialization
+ kwargs = dict(
+ env_name=env_name,
+ camera_names_to_sizes=camera_names_to_sizes,
+ center_crop_images=center_crop_images,
+ general_cfg_file=general_cfg_file,
+ control_freq=control_freq,
+ controller_type=controller_type,
+ controller_cfg_file=controller_cfg_file,
+ controller_cfg_dict=controller_cfg_dict,
+ use_depth_obs=use_depth_obs,
+ state_freq=state_freq,
+ control_timeout=control_timeout,
+ has_gripper=has_gripper,
+ use_visualizer=use_visualizer,
+ debug=debug,
+ )
+ self._init_kwargs = deepcopy(kwargs)
+
+ # connect to robot
+ self.robot_interface = FrankaInterface(
+ general_cfg_file=os.path.join(config_root, general_cfg_file),
+ control_freq=control_freq,
+ state_freq=state_freq,
+ control_timeout=control_timeout,
+ has_gripper=has_gripper,
+ use_visualizer=use_visualizer,
+ debug=debug,
+ )
+
+ # TODO: clean up camera ID definition later
+
+ # start camera interfaces
+ camera_ids = list(range(len(self.camera_names_to_sizes)))
+ self.cr_interfaces = {}
+ for c_id, c_name in enumerate(self.camera_names_to_sizes):
+ cr_interface = CameraRedisSubInterface(camera_id=c_id, use_depth=self.use_depth_obs)
+ cr_interface.start()
+ self.cr_interfaces[c_name] = cr_interface
+
+ # IMPORTANT: initialize JIT functions that may need to compile
+ self._compile_jit_functions()
+
+ def _compile_jit_functions(self):
+ """
+ Helper function to incur the cost of compiling jit functions used by this class upfront.
+
+ NOTE: this function looks strange because we apparently need to make it look like the env.step function
+ for it to compile properly, otherwise we will have a heavy delay on the first env.step call...
+
+ TODO: figure out why this needs to look like the step function code below...
+ """
+
+ # current robot state to use as reference
+ # ee_pos, ee_quat = self.robot_interface.ee_pose
+ ee_mat = U.quat2mat(np.array([0., 0., 0., 1.]))
+ ee_quat_hat = U.mat2quat(ee_mat)
+
+ # convert delta axis-angle to delta rotation matrix, and from there, to absolute target rotation
+ drot = np.array([0., 0., 0.05])
+ angle = np.linalg.norm(drot)
+ if U.isclose(angle, 0.):
+ drot_quat = np.array([0., 0., 0., 1.])
+ else:
+ axis = drot / angle
+ drot_quat = U.axisangle2quat(axis, angle)
+
+ # get target rotation
+ drot_mat = U.quat2mat(drot_quat)
+ target_rot_mat = (drot_mat.T).dot(ee_mat)
+ target_rot_quat = U.mat2quat(target_rot_mat)
+
+ if self.absolute_actions:
+ test_mat = T.quat2mat(T.axisangle2quat(drot))
+
+ def _get_unified_getter(self):
+ """
+ For HITL-TAMP teleoperation only - provides access to important information for perception.
+ """
+ from htamp.scripts.test_real_world import UnifiedGetter
+ return UnifiedGetter(
+ use_real_robot=True,
+ robot_interface=self.robot_interface,
+ camera_interface=self.cr_interfaces["front_image"],
+ )
+
+ def switch_controllers(self, controller_dict):
+ """
+ Switch the controller type and controller config being used. Useful
+ for switching inbetween two different kinds of controllers during an
+ episode - for example, OSC and Joint Impedance.
+
+ Args:
+ controller_dict (dict): dictionary that contains two keys
+ type (str): type of controller
+ cfg (easydict): controller config
+
+ Returns:
+ old_controller_dict (dict): the previous @controller_dict
+ """
+ old_controller_dict = dict(type=self.controller_type, cfg=deepcopy(self.controller_cfg))
+ print("*" * 50)
+ print("SWITCH TO CONTROLLER TYPE: {}".format(controller_dict["type"]))
+ print("*" * 50)
+ self.controller_type = controller_dict["type"]
+ self.controller_cfg = controller_dict["cfg"]
+ return old_controller_dict
+
+ def step(self, action, need_obs=True):
+ """
+ Step in the environment with an action.
+
+ Args:
+ action (np.array): action to take, should be in [-1, 1]
+ need_obs (bool): if False, don't return the observation, because this
+ can involve copying image data around. This allows for more
+ flexibility on when observations are retrieved.
+
+ Returns:
+ observation (dict): new observation dictionary
+ reward (float): reward for this step
+ done (bool): whether the task is done
+ info (dict): extra information
+ """
+ # print("step got action: {}".format(action))
+ if self.controller_type == "OSC_POSE":
+ assert len(action.shape) == 1 and action.shape[0] == 7, "action has incorrect dimensions"
+
+ if self.absolute_actions:
+ # convert action from absolute to relative for compatibility with rest of code
+ action = np.array(action)
+
+ # absolute pose target
+ target_pos = action[:3]
+ target_rot = T.quat2mat(T.axisangle2quat(action[3:6]))
+
+ # current pose
+ last_robot_state = self.robot_interface._state_buffer[-1]
+ ee_pose = np.array(last_robot_state.O_T_EE).reshape((4, 4)).T
+ start_pos = ee_pose[:3, 3]
+ start_rot = ee_pose[:3, :3]
+
+ # TODO: remove hardcode
+ max_dpos = np.array([0.08, 0.08, 0.08])
+ max_drot = np.array([0.5, 0.5, 0.5])
+
+ # copied from MG class (TODO: unify)
+ delta_position = target_pos - start_pos
+ delta_position = np.clip(delta_position / max_dpos, -1., 1.)
+
+ delta_rot_mat = target_rot.dot(start_rot.T)
+ delta_rot_quat = U.mat2quat(delta_rot_mat)
+ delta_rot_aa = U.quat2axisangle(delta_rot_quat)
+ delta_rotation = delta_rot_aa[0] * delta_rot_aa[1]
+ delta_rotation = np.clip(delta_rotation / max_drot, -1., 1.)
+
+ # relative action
+ action[:3] = delta_position
+ action[3:6] = delta_rotation
+ action[6:] = np.clip(action[6:], -1., 1.)
+
+ assert np.min(action) >= -1. and np.max(action) <= 1., "incorrect action bounds"
+ elif self.controller_type == "JOINT_IMPEDANCE":
+ assert len(action.shape) == 1 and action.shape[0] == 8, "action has incorrect dimensions"
+ assert not self.absolute_actions
+ if not np.any(action[:7]):
+ raise Exception("GOT ZERO ACTION WITH JOINT IMPEDANCE CONTROLLER - TERMINATING")
+
+ # compare current joint position with issued action
+ last_robot_state = self.robot_interface._state_buffer[-1]
+ cur_q = np.array(last_robot_state.q)
+
+ # print("joint action: {}".format(action[:7]))
+ # print("current joints: {}".format(cur_q))
+ # print("absolute error: {}".format(np.abs(action[:7] - cur_q)))
+ # print("max absolute error: {}".format(np.max(np.abs(action[:7] - cur_q))))
+
+ # if np.max(np.abs(action[:7] - cur_q)) > 0.2:
+ # raise Exception("max absolute error too high - stopping")
+
+ # TODO: joint impedance controller takes in raw joint positions - we might need to change this later, if we want to learn from these actions
+ # assert np.min(action) >= -1. and np.max(action) <= 1., "incorrect action bounds"
+
+ # meaure rate-limiting
+ # self.rate.sleep()
+ self.rate_measure.measure()
+
+ self.timers.tic("real_panda_step")
+
+ self.robot_interface.control(
+ control_type=self.controller_type,
+ action=action,
+ controller_cfg=self.controller_cfg,
+ )
+
+ # remember the last gripper action taken in this variable
+ gripper_command = action[-1:]
+ self.did_grasp = (gripper_command[0] > 0.)
+
+ # get observation
+ obs = None
+ if need_obs:
+ obs = self.get_observation()
+ r = self.get_reward()
+ done = self.is_done()
+
+ self.timers.toc("real_panda_step")
+
+ return obs, r, done, {}
+
+ def reset(self):
+ """
+ Reset environment.
+
+ Returns:
+ observation (dict): initial observation dictionary.
+ """
+
+ # self.robot_interface.close()
+ # del self.robot_interface
+ # self.robot_interface = FrankaInterface(
+ # general_cfg_file=os.path.join(config_root, self._init_kwargs['general_cfg_file']),
+ # control_freq=self._init_kwargs['control_freq'],
+ # state_freq=self._init_kwargs['state_freq'],
+ # control_timeout=self._init_kwargs['control_timeout'],
+ # has_gripper=self._init_kwargs['has_gripper'],
+ # use_visualizer=self._init_kwargs['use_visualizer'],
+ # debug=self._init_kwargs['debug'],
+ # )
+
+ self.robot_interface.clear_buffer()
+
+ print("restarting the robot interface")
+
+ # Code below based on https://github.com/UT-Austin-RPL/robot_infra/blob/master/gprs/examples/reset_robot_joints.py
+
+ # Golden resetting joints
+ reset_joint_positions = [0.09162008114028396, -0.19826458111314524, -0.01990020486871322, -2.4732269941140346, -0.01307073642274261, 2.30396583422025, 0.8480939705504309]
+
+ # This is for varying initialization of joints a little bit to
+ # increase data variation.
+ # reset_joint_positions = [e + np.clip(np.random.randn() * 0.005, -0.005, 0.005) for e in reset_joint_positions]
+ action = reset_joint_positions + [-1.]
+
+ # temp robot interface to use for joint position control
+ # tmp_robot_interface = FrankaInterface(os.path.join(config_root, self.general_cfg_file), use_visualizer=False)
+ # tmp_controller_cfg = YamlConfig(os.path.join(config_root, self.controller_cfg_file)).as_easydict()
+ tmp_controller_cfg = deepcopy(self.controller_cfg)
+
+ while True:
+ if len(self.robot_interface._state_buffer) > 0:
+ # print(self.robot_interface._state_buffer[-1].q)
+ # print(reset_joint_positions)
+ # print(np.max(np.abs(np.array(self.robot_interface._state_buffer[-1].q) - np.array(reset_joint_positions))))
+ # print("-----------------------")
+
+ # if np.max(np.abs(np.array(self.robot_interface._state_buffer[-1].q) - np.array(reset_joint_positions))) < 1e-3:
+ if np.max(np.abs(np.array(self.robot_interface._state_buffer[-1].q) - np.array(reset_joint_positions))) < 1e-2:
+ break
+
+ self.robot_interface.control(
+ control_type="JOINT_POSITION",
+ action=action,
+ controller_cfg=tmp_controller_cfg,
+ )
+
+ # tmp_robot_interface.close()
+
+ # We added this sleep here to give the C++ controller time to reset from joint control mode to no control mode
+ # to prevent some issues.
+ time.sleep(1.0)
+ print("RESET DONE")
+
+ self.did_grasp = False
+
+ return self.get_observation()
+
+ def reset_to(self, state):
+ """
+ Reset to a specific state. On real robot, we visualize the start image,
+ and a human should manually reset the scene.
+
+ Reset to a specific simulator state.
+
+ Args:
+ state (dict): initial state that contains:
+ - image (np.ndarray): initial workspace image
+
+ Returns:
+ None
+ """
+ assert "front_image" in state
+ ref_img = cv2.cvtColor(state["front_image"], cv2.COLOR_RGB2BGR)
+
+ print("\n" + "*" * 50)
+ print("Reset environment to image shown in left pane")
+ print("Press 'c' when ready to continue.")
+ print("*" * 50 + "\n")
+ while(True):
+ # read current image
+ cur_img = self._get_image(camera_name="front_image")
+ if self.use_depth_obs:
+ cur_img = cur_img[0]
+
+ # concatenate frames to display
+ img = np.concatenate([ref_img, cur_img], axis=1)
+
+ # display frame
+ cv2.imshow('initial state alignment window', img)
+ if cv2.waitKey(1) & 0xFF == ord('c'):
+ cv2.destroyAllWindows()
+ break
+
+ def render(self, mode="human", height=None, width=None, camera_name=None, **kwargs):
+ """
+ Render from simulation to either an on-screen window or off-screen to RGB array.
+
+ Args:
+ mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering
+ height (int): height of image to render - only used if mode is "rgb_array"
+ width (int): width of image to render - only used if mode is "rgb_array"
+ """
+ if mode =="human":
+ raise Exception("on-screen rendering not supported currently")
+ if mode == "rgb_array":
+ # assert (height is None) and (width is None), "cannot resize images"
+ assert camera_name in self.camera_names_to_sizes, "invalid camera name"
+ imgs = self.cr_interfaces[camera_name].get_img()
+ return imgs["color"][..., ::-1]
+ # return self._get_image(camera_name=camera_name)[..., ::-1]
+ else:
+ raise NotImplementedError("mode={} is not implemented".format(mode))
+
+ def get_observation(self, obs=None):
+ """
+ Get current environment observation dictionary.
+
+ Args:
+ ob (np.array): current observation dictionary.
+ """
+ self.timers.tic("get_observation")
+ observation = {}
+ last_robot_state = self.robot_interface._state_buffer[-1]
+ last_gripper_state = self.robot_interface._gripper_state_buffer[-1]
+ ee_pose = np.array(last_robot_state.O_T_EE).reshape((4, 4)).T
+ if np.count_nonzero(ee_pose.reshape(-1)) == 0:
+ raise Exception("GOT ZERO EE POSE")
+ ee_pos = ee_pose[:3, 3]
+ ee_quat = U.mat2quat(ee_pose[:3, :3])
+ observation["ee_pose"] = np.concatenate([ee_pos, ee_quat])
+ observation["joint_positions"] = np.array(last_robot_state.q)
+ observation["joint_velocities"] = np.array(last_robot_state.dq)
+ observation["gripper_position"] = np.array(last_gripper_state.width)
+ # observation["gripper_velocity"] = self.robot_interface.gripper_velocity
+ for cam_name in self.camera_names_to_sizes:
+ im = self._get_image(camera_name=cam_name)
+ if self.use_depth_obs:
+ im, depth_im = im
+ # im, depth_im, depth_im_unaligned = im
+ # observation[cam_name + "_depth"] = depth_im
+ # observation[cam_name + "_unaligned_depth"] = depth_im_unaligned
+ if (not self._exclude_depth_from_obs):
+ depth_im_mod = cam_name + "_depth"
+ if self.postprocess_visual_obs and (depth_im_mod in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=depth_im_mod, obs_modality="depth"):
+ depth_im = ObsUtils.process_obs(obs=depth_im, obs_key=depth_im_mod)
+ observation[depth_im_mod] = depth_im
+ im = im[..., ::-1]
+ if self.postprocess_visual_obs:
+ # NOTE: commented out for now, since run-trained-agent was running into issues with unneeded agent modalities that were present in @self.camera_names_to_sizes
+ # assert (cam_name in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=cam_name, obs_modality="rgb")
+ im = ObsUtils.process_obs(obs=im, obs_key=cam_name)
+ observation[cam_name] = im
+ self.timers.toc("get_observation")
+ return observation
+
+ def _get_image(self, camera_name):
+ """
+ Get image from camera interface
+ """
+
+ # get image
+ imgs = self.cr_interfaces[camera_name].get_img()
+ im = imgs["color"]
+
+ # resize image
+ im_size = self.camera_names_to_sizes[camera_name]
+ if im_size is not None:
+ im = Image.fromarray(im).resize((im_size[1], im_size[0]), Image.BILINEAR)
+ im = np.array(im).astype(np.uint8)
+
+ if self.center_crop_images:
+ # center crop image
+ crop_size = min(im.shape[:2])
+ im = center_crop(im, crop_size, crop_size)
+
+ if self.use_depth_obs:
+ depth_im = imgs["depth"]
+ if im_size is not None:
+ # depth_im = Image.fromarray(depth_im).resize((im_size[1], im_size[0]), Image.BILINEAR)
+ depth_im = Image.fromarray(depth_im).resize((im_size[1], im_size[0]))
+ # note: depth images are uint16, with default scale 0.001m
+ depth_im = np.array(depth_im).astype(np.uint16)
+ if len(depth_im.shape) < 3:
+ depth_im = depth_im[..., None] # add channel dimension
+ if self.center_crop_images:
+ depth_im = center_crop(depth_im, crop_size, crop_size)
+ return im, depth_im
+ # depth_images = []
+ # for k in ["depth", "unaligned_depth"]:
+ # depth_im = imgs[k]
+ # if im_size is not None:
+ # # depth_im = Image.fromarray(depth_im).resize((im_size[1], im_size[0]), Image.BILINEAR)
+ # depth_im = Image.fromarray(depth_im).resize((im_size[1], im_size[0]))
+ # # note: depth images are uint16, with default scale 0.001m
+ # depth_im = np.array(depth_im).astype(np.uint16)
+ # if len(depth_im.shape) < 3:
+ # depth_im = depth_im[..., None] # add channel dimension
+ # if self.center_crop_images:
+ # depth_im = center_crop(depth_im, crop_size, crop_size)
+ # depth_images.append(depth_im)
+ # return im, depth_images[0], depth_images[1]
+ return im
+
+ def get_state(self):
+ """
+ Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
+ """
+ return dict(states=np.zeros(1))
+ # raise Exception("Real robot has no simulation state.")
+
+ def get_reward(self):
+ """
+ Get current reward.
+ """
+ return 0.
+
+ def get_goal(self):
+ """
+ Get goal observation. Not all environments support this.
+ """
+ raise NotImplementedError
+
+ def set_goal(self, **kwargs):
+ """
+ Set goal observation with external specification. Not all environments support this.
+ """
+ raise NotImplementedError
+
+ def is_done(self):
+ """
+ Check if the task is done (not necessarily successful).
+ """
+ return False
+
+ def is_success(self):
+ """
+ Check if the task condition(s) is reached. Should return a dictionary
+ { str: bool } with at least a "task" key for the overall task success,
+ and additional optional keys corresponding to other task criteria.
+ """
+
+ # real robot environments don't usually have a success check - this must be done manually
+ return { "task" : False }
+
+ @property
+ def action_dimension(self):
+ """
+ Returns dimension of actions (int).
+ """
+ if self.controller_type == "OSC_POSE":
+ return 7
+ elif self.controller_type == "JOINT_IMPEDANCE":
+ return 8
+ assert False, "should never get here"
+
+ @property
+ def action_dim(self):
+ """
+ Returns dimension of actions (int).
+ """
+ return self.action_dimension
+
+ @property
+ def name(self):
+ """
+ Returns name of environment name (str).
+ """
+ # return self._env_name
+
+ # for real robot. ensure class name is stored in env meta (as env name) for use with any external
+ # class registries
+ return self.__class__.__name__
+
+ @property
+ def type(self):
+ """
+ Returns environment type (int) for this kind of environment.
+ This helps identify this env class.
+ """
+ return EB.EnvType.GPRS_REAL_TYPE
+
+ def serialize(self):
+ """
+ Save all information needed to re-instantiate this environment in a dictionary.
+ This is the same as @env_meta - environment metadata stored in hdf5 datasets,
+ and used in utils/env_utils.py.
+ """
+ return dict(env_name=self.name, type=self.type, env_kwargs=deepcopy(self._init_kwargs))
+
+ @classmethod
+ def create_for_data_processing(
+ cls,
+ env_name,
+ camera_names,
+ camera_height,
+ camera_width,
+ reward_shaping,
+ render=None,
+ render_offscreen=None,
+ use_image_obs=None,
+ use_depth_obs=None,
+ **kwargs,
+ ):
+ """
+ Create environment for processing datasets, which includes extracting
+ observations, labeling dense / sparse rewards, and annotating dones in
+ transitions. For gym environments, input arguments (other than @env_name)
+ are ignored, since environments are mostly pre-configured.
+
+ Args:
+ env_name (str): name of gym environment to create
+
+ Returns:
+ env (EnvRealPanda instance)
+ """
+
+ # initialize obs utils so it knows which modalities are image modalities
+ assert "camera_names_to_sizes" in kwargs
+ image_modalities = list(kwargs["camera_names_to_sizes"].keys())
+ obs_modality_specs = {
+ "obs": {
+ "low_dim": [], # technically unused, so we don't have to specify all of them
+ "image": image_modalities,
+ }
+ }
+ ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs)
+
+ # note that @postprocess_visual_obs is False since this env's images will be written to a dataset
+ return cls(
+ env_name=env_name,
+ render=False,
+ render_offscreen=True,
+ use_image_obs=True,
+ use_depth_obs=use_depth_obs if use_depth_obs is not None else False,
+ postprocess_visual_obs=False,
+ **kwargs,
+ )
+
+ @property
+ def rollout_exceptions(self):
+ """
+ Return tuple of exceptions to except when doing rollouts. This is useful to ensure
+ that the entire training run doesn't crash because of a bad policy that causes unstable
+ simulation computations.
+ """
+ return ()
+
+ @property
+ def base_env(self):
+ """
+ Grabs base simulation environment.
+ """
+ # we don't wrap any env
+ return self
+
+ def __repr__(self):
+ """
+ Pretty-print env description.
+ """
+ return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4)
+
+ def close(self):
+ """
+ Clean up env
+ """
+ for c_name in self.cr_interfaces:
+ self.cr_interfaces[c_name].stop()
+ self.robot_interface.close()
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/env_robosuite.py b/phantom/submodules/phantom-robomimic/robomimic/envs/env_robosuite.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddd958012116e1700a27711f6af39af9dd0a7e29
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/env_robosuite.py
@@ -0,0 +1,537 @@
+"""
+This file contains the robosuite environment wrapper that is used
+to provide a standardized environment API for training policies and interacting
+with metadata present in datasets.
+"""
+import json
+import os
+import numpy as np
+from copy import deepcopy
+
+import robosuite
+import robosuite.utils.transform_utils as T
+try:
+ # this is needed for ensuring robosuite can find the additional mimicgen environments (see https://mimicgen.github.io)
+ import mimicgen_envs
+except ImportError:
+ pass
+
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.envs.env_base as EB
+
+# protect against missing mujoco-py module, since robosuite might be using mujoco-py or DM backend
+try:
+ import mujoco_py
+ MUJOCO_EXCEPTIONS = [mujoco_py.builder.MujocoException]
+except ImportError:
+ MUJOCO_EXCEPTIONS = []
+
+
+class EnvRobosuite(EB.EnvBase):
+ """Wrapper class for robosuite environments (https://github.com/ARISE-Initiative/robosuite)"""
+ def __init__(
+ self,
+ env_name,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ use_depth_obs=False,
+ postprocess_visual_obs=True,
+ **kwargs,
+ ):
+ """
+ Args:
+ env_name (str): name of environment. Only needs to be provided if making a different
+ environment from the one in @env_meta.
+
+ render (bool): if True, environment supports on-screen rendering
+
+ render_offscreen (bool): if True, environment supports off-screen rendering. This
+ is forced to be True if @env_meta["use_images"] is True.
+
+ use_image_obs (bool): if True, environment is expected to render rgb image observations
+ on every env.step call. Set this to False for efficiency reasons, if image
+ observations are not required.
+
+ use_depth_obs (bool): if True, environment is expected to render depth image observations
+ on every env.step call. Set this to False for efficiency reasons, if depth
+ observations are not required.
+
+ postprocess_visual_obs (bool): if True, postprocess image observations
+ to prepare for learning. This should only be False when extracting observations
+ for saving to a dataset (to save space on RGB images for example).
+ """
+ self.postprocess_visual_obs = postprocess_visual_obs
+ self.use_depth_obs = use_depth_obs
+
+ # robosuite version check
+ self._is_v1 = (robosuite.__version__.split(".")[0] == "1")
+ if self._is_v1:
+ assert (int(robosuite.__version__.split(".")[1]) >= 2), "only support robosuite v0.3 and v1.2+"
+
+ kwargs = deepcopy(kwargs)
+
+ # update kwargs based on passed arguments
+ update_kwargs = dict(
+ has_renderer=render,
+ has_offscreen_renderer=(render_offscreen or use_image_obs),
+ ignore_done=True,
+ use_object_obs=True,
+ use_camera_obs=use_image_obs,
+ camera_depths=use_depth_obs,
+ )
+ kwargs.update(update_kwargs)
+
+ if self._is_v1:
+ if kwargs["has_offscreen_renderer"]:
+ cuda_visible_device = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+ if cuda_visible_device.isnumeric():
+ # assume that user specified a specific GPU ID
+ kwargs["render_gpu_device_id"] = int(cuda_visible_device)
+ else:
+ # ensure that we select the correct GPU device for rendering by testing for EGL rendering
+ # NOTE: this package should be installed from this link (https://github.com/StanfordVL/egl_probe)
+ import egl_probe
+ valid_gpu_devices = egl_probe.get_available_devices()
+ if len(valid_gpu_devices) > 0:
+ kwargs["render_gpu_device_id"] = valid_gpu_devices[0]
+ else:
+ # make sure gripper visualization is turned off (we almost always want this for learning)
+ kwargs["gripper_visualization"] = False
+ del kwargs["camera_depths"]
+ kwargs["camera_depth"] = use_depth_obs # rename kwarg
+
+ self._env_name = env_name
+ self._init_kwargs = deepcopy(kwargs)
+ self.env = robosuite.make(self._env_name, **kwargs)
+
+ if self._is_v1:
+ # Make sure joint position observations and eef vel observations are active
+ for ob_name in self.env.observation_names:
+ if ("joint_pos" in ob_name) or ("eef_vel" in ob_name):
+ self.env.modify_observable(observable_name=ob_name, attribute="active", modifier=True)
+
+ def step(self, action):
+ """
+ Step in the environment with an action.
+
+ Args:
+ action (np.array): action to take
+
+ Returns:
+ observation (dict): new observation dictionary
+ reward (float): reward for this step
+ done (bool): whether the task is done
+ info (dict): extra information
+ """
+ obs, r, done, info = self.env.step(action)
+ obs = self.get_observation(obs)
+ return obs, r, self.is_done(), info
+
+ def reset(self):
+ """
+ Reset environment.
+
+ Returns:
+ observation (dict): initial observation dictionary.
+ """
+ di = self.env.reset()
+ return self.get_observation(di)
+
+ def reset_to(self, state):
+ """
+ Reset to a specific simulator state.
+
+ Args:
+ state (dict): current simulator state that contains one or more of:
+ - states (np.ndarray): initial state of the mujoco environment
+ - model (str): mujoco scene xml
+
+ Returns:
+ observation (dict): observation dictionary after setting the simulator state (only
+ if "states" is in @state)
+ """
+ should_ret = False
+ if "model" in state:
+ self.reset()
+ robosuite_version_id = int(robosuite.__version__.split(".")[1])
+ if robosuite_version_id <= 3:
+ from robosuite.utils.mjcf_utils import postprocess_model_xml
+ xml = postprocess_model_xml(state["model"])
+ else:
+ # v1.4 and above use the class-based edit_model_xml function
+ xml = self.env.edit_model_xml(state["model"])
+ self.env.reset_from_xml_string(xml)
+ self.env.sim.reset()
+ if not self._is_v1:
+ # hide teleop visualization after restoring from model
+ self.env.sim.model.site_rgba[self.env.eef_site_id] = np.array([0., 0., 0., 0.])
+ self.env.sim.model.site_rgba[self.env.eef_cylinder_id] = np.array([0., 0., 0., 0.])
+ if "states" in state:
+ self.env.sim.set_state_from_flattened(state["states"])
+ self.env.sim.forward()
+ should_ret = True
+
+ if "goal" in state:
+ self.set_goal(**state["goal"])
+ if should_ret:
+ # only return obs if we've done a forward call - otherwise the observations will be garbage
+ return self.get_observation()
+ return None
+
+ def render(self, mode="human", height=None, width=None, camera_name="agentview"):
+ """
+ Render from simulation to either an on-screen window or off-screen to RGB array.
+
+ Args:
+ mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering
+ height (int): height of image to render - only used if mode is "rgb_array"
+ width (int): width of image to render - only used if mode is "rgb_array"
+ camera_name (str): camera name to use for rendering
+ """
+ if mode == "human":
+ cam_id = self.env.sim.model.camera_name2id(camera_name)
+ self.env.viewer.set_camera(cam_id)
+ return self.env.render()
+ elif mode == "rgb_array":
+ im = self.env.sim.render(height=height, width=width, camera_name=camera_name)
+ if self.use_depth_obs:
+ # render() returns a tuple when self.use_depth_obs=True
+ return im[0][::-1]
+ return im[::-1]
+ else:
+ raise NotImplementedError("mode={} is not implemented".format(mode))
+
+ def get_observation(self, di=None):
+ """
+ Get current environment observation dictionary.
+
+ Args:
+ di (dict): current raw observation dictionary from robosuite to wrap and provide
+ as a dictionary. If not provided, will be queried from robosuite.
+ """
+ if di is None:
+ di = self.env._get_observations(force_update=True) if self._is_v1 else self.env._get_observation()
+ ret = {}
+ for k in di:
+ if (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=k, obs_modality="rgb"):
+ ret[k] = di[k]
+ if self.postprocess_visual_obs:
+ ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)
+ elif (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=k, obs_modality="depth"):
+ ret[k] = di[k]
+ if len(ret[k].shape) == 2:
+ ret[k] = ret[k][..., None] # (H, W, 1)
+ assert len(ret[k].shape) == 3
+ # scale entries in depth map to correspond to real distance.
+ ret[k] = self.get_real_depth_map(ret[k])
+ if self.postprocess_visual_obs:
+ ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)
+ elif (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=k, obs_modality="depth"):
+ ret[k] = di[k]
+ if len(ret[k].shape) == 2:
+ ret[k] = ret[k][..., None] # (H, W, 1)
+ assert len(ret[k].shape) == 3
+ # scale entries in depth map to correspond to real distance.
+ ret[k] = self.get_real_depth_map(ret[k])
+ if self.postprocess_visual_obs:
+ ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)
+ elif k == "frontview_segmentation_instance" or k == "agentview_segmentation_instance":
+ ret[k] = di[k]
+ if len(ret[k].shape) == 2:
+ ret[k] = ret[k][..., None] # (H, W, 1)
+ elif k == "frontview_depth" or "agentview_depth":
+ ret[k] = di[k]
+ if len(ret[k].shape) == 2:
+ ret[k] = ret[k][..., None] # (H, W, 1)
+
+
+ # "object" key contains object information
+ if "object-state" in di.keys():
+ ret["object"] = np.array(di["object-state"])
+
+ if self._is_v1:
+ for robot in self.env.robots:
+ # add all robot-arm-specific observations. Note the (k not in ret) check
+ # ensures that we don't accidentally add robot wrist images a second time
+ pf = robot.robot_model.naming_prefix
+ for k in di:
+ if k.startswith(pf) and (k not in ret) and \
+ (not k.endswith("proprio-state")):
+ ret[k] = np.array(di[k])
+ else:
+ # minimal proprioception for older versions of robosuite
+ ret["proprio"] = np.array(di["robot-state"])
+ ret["eef_pos"] = np.array(di["eef_pos"])
+ ret["eef_quat"] = np.array(di["eef_quat"])
+ ret["gripper_qpos"] = np.array(di["gripper_qpos"])
+ return ret
+
+ def get_real_depth_map(self, depth_map):
+ """
+ Reproduced from https://github.com/ARISE-Initiative/robosuite/blob/c57e282553a4f42378f2635b9a3cbc4afba270fd/robosuite/utils/camera_utils.py#L106
+ since older versions of robosuite do not have this conversion from normalized depth values returned by MuJoCo
+ to real depth values.
+ """
+ # Make sure that depth values are normalized
+ assert np.all(depth_map >= 0.0) and np.all(depth_map <= 1.0)
+ extent = self.env.sim.model.stat.extent
+ far = self.env.sim.model.vis.map.zfar * extent
+ near = self.env.sim.model.vis.map.znear * extent
+ return near / (1.0 - depth_map * (1.0 - near / far))
+
+ def get_camera_intrinsic_matrix(self, camera_name, camera_height, camera_width):
+ """
+ Obtains camera intrinsic matrix.
+ Args:
+ camera_name (str): name of camera
+ camera_height (int): height of camera images in pixels
+ camera_width (int): width of camera images in pixels
+ Return:
+ K (np.array): 3x3 camera matrix
+ """
+ cam_id = self.env.sim.model.camera_name2id(camera_name)
+ fovy = self.env.sim.model.cam_fovy[cam_id]
+ f = 0.5 * camera_height / np.tan(fovy * np.pi / 360)
+ K = np.array([[f, 0, camera_width / 2], [0, f, camera_height / 2], [0, 0, 1]])
+ return K
+
+ def get_camera_extrinsic_matrix(self, camera_name):
+ """
+ Returns a 4x4 homogenous matrix corresponding to the camera pose in the
+ world frame. MuJoCo has a weird convention for how it sets up the
+ camera body axis, so we also apply a correction so that the x and y
+ axis are along the camera view and the z axis points along the
+ viewpoint.
+ Normal camera convention: https://docs.opencv.org/2.4/modules/calib3d/doc/camera_calibration_and_3d_reconstruction.html
+ Args:
+ camera_name (str): name of camera
+ Return:
+ R (np.array): 4x4 camera extrinsic matrix
+ """
+ cam_id = self.env.sim.model.camera_name2id(camera_name)
+ camera_pos = self.env.sim.data.cam_xpos[cam_id]
+ camera_rot = self.env.sim.data.cam_xmat[cam_id].reshape(3, 3)
+ R = T.make_pose(camera_pos, camera_rot)
+
+ # IMPORTANT! This is a correction so that the camera axis is set up along the viewpoint correctly.
+ camera_axis_correction = np.array(
+ [[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
+ )
+ R = R @ camera_axis_correction
+ return R
+
+ def get_camera_transform_matrix(self, camera_name, camera_height, camera_width):
+ """
+ Camera transform matrix to project from world coordinates to pixel coordinates.
+ Args:
+ camera_name (str): name of camera
+ camera_height (int): height of camera images in pixels
+ camera_width (int): width of camera images in pixels
+ Return:
+ K (np.array): 4x4 camera matrix to project from world coordinates to pixel coordinates
+ """
+ R = self.get_camera_extrinsic_matrix(camera_name=camera_name)
+ K = self.get_camera_intrinsic_matrix(
+ camera_name=camera_name, camera_height=camera_height, camera_width=camera_width
+ )
+ K_exp = np.eye(4)
+ K_exp[:3, :3] = K
+
+ # Takes a point in world, transforms to camera frame, and then projects onto image plane.
+ return K_exp @ T.pose_inv(R)
+
+ def get_state(self):
+ """
+ Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
+ """
+ xml = self.env.sim.model.get_xml() # model xml file
+ state = np.array(self.env.sim.get_state().flatten()) # simulator state
+ return dict(model=xml, states=state)
+
+ def get_reward(self):
+ """
+ Get current reward.
+ """
+ return self.env.reward()
+
+ def get_goal(self):
+ """
+ Get goal observation. Not all environments support this.
+ """
+ return self.get_observation(self.env._get_goal())
+
+ def set_goal(self, **kwargs):
+ """
+ Set goal observation with external specification. Not all environments support this.
+ """
+ return self.env.set_goal(**kwargs)
+
+ def is_done(self):
+ """
+ Check if the task is done (not necessarily successful).
+ """
+
+ # Robosuite envs always rollout to fixed horizon.
+ return False
+
+ def is_success(self):
+ """
+ Check if the task condition(s) is reached. Should return a dictionary
+ { str: bool } with at least a "task" key for the overall task success,
+ and additional optional keys corresponding to other task criteria.
+ """
+ succ = self.env._check_success()
+ if isinstance(succ, dict):
+ assert "task" in succ
+ return succ
+ return { "task" : succ }
+
+ @property
+ def action_dimension(self):
+ """
+ Returns dimension of actions (int).
+ """
+ return self.env.action_spec[0].shape[0]
+
+ @property
+ def name(self):
+ """
+ Returns name of environment name (str).
+ """
+ return self._env_name
+
+ @property
+ def type(self):
+ """
+ Returns environment type (int) for this kind of environment.
+ This helps identify this env class.
+ """
+ return EB.EnvType.ROBOSUITE_TYPE
+
+ @property
+ def version(self):
+ """
+ Returns version of robosuite used for this environment, eg. 1.2.0
+ """
+ return robosuite.__version__
+
+ def serialize(self):
+ """
+ Save all information needed to re-instantiate this environment in a dictionary.
+ This is the same as @env_meta - environment metadata stored in hdf5 datasets,
+ and used in utils/env_utils.py.
+ """
+ return dict(
+ env_name=self.name,
+ env_version=self.version,
+ type=self.type,
+ env_kwargs=deepcopy(self._init_kwargs)
+ )
+
+ @classmethod
+ def create_for_data_processing(
+ cls,
+ env_name,
+ camera_names,
+ camera_height,
+ camera_width,
+ reward_shaping,
+ render=None,
+ render_offscreen=None,
+ use_image_obs=None,
+ use_depth_obs=None,
+ **kwargs,
+ ):
+ """
+ Create environment for processing datasets, which includes extracting
+ observations, labeling dense / sparse rewards, and annotating dones in
+ transitions.
+
+ Args:
+ env_name (str): name of environment
+ camera_names (list of str): list of camera names that correspond to image observations
+ camera_height (int): camera height for all cameras
+ camera_width (int): camera width for all cameras
+ reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
+ render (bool or None): optionally override rendering behavior. Defaults to False.
+ render_offscreen (bool or None): optionally override rendering behavior. The default value is True if
+ @camera_names is non-empty, False otherwise.
+ use_image_obs (bool or None): optionally override rendering behavior. The default value is True if
+ @camera_names is non-empty, False otherwise.
+ use_depth_obs (bool): if True, use depth observations
+ """
+ is_v1 = (robosuite.__version__.split(".")[0] == "1")
+ has_camera = (len(camera_names) > 0)
+
+ new_kwargs = {
+ "reward_shaping": reward_shaping,
+ }
+
+ if has_camera:
+ if is_v1:
+ new_kwargs["camera_names"] = list(camera_names)
+ new_kwargs["camera_heights"] = camera_height
+ new_kwargs["camera_widths"] = camera_width
+ else:
+ assert len(camera_names) == 1
+ if has_camera:
+ new_kwargs["camera_name"] = camera_names[0]
+ new_kwargs["camera_height"] = camera_height
+ new_kwargs["camera_width"] = camera_width
+
+ kwargs.update(new_kwargs)
+
+ # also initialize obs utils so it knows which modalities are image modalities
+ image_modalities = list(camera_names)
+ depth_modalities = list(camera_names)
+ if is_v1:
+ image_modalities = ["{}_image".format(cn) for cn in camera_names]
+ depth_modalities = ["{}_depth".format(cn) for cn in camera_names]
+ elif has_camera:
+ # v0.3 only had support for one image, and it was named "image"
+ assert len(image_modalities) == 1
+ image_modalities = ["image"]
+ depth_modalities = ["depth"]
+ obs_modality_specs = {
+ "obs": {
+ "low_dim": [], # technically unused, so we don't have to specify all of them
+ "rgb": image_modalities,
+ }
+ }
+ if use_depth_obs:
+ obs_modality_specs["obs"]["depth"] = depth_modalities
+ ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs)
+
+ # note that @postprocess_visual_obs is False since this env's images will be written to a dataset
+ return cls(
+ env_name=env_name,
+ render=(False if render is None else render),
+ render_offscreen=(has_camera if render_offscreen is None else render_offscreen),
+ use_image_obs=(has_camera if use_image_obs is None else use_image_obs),
+ use_depth_obs=use_depth_obs,
+ postprocess_visual_obs=False,
+ **kwargs,
+ )
+
+ @property
+ def rollout_exceptions(self):
+ """
+ Return tuple of exceptions to except when doing rollouts. This is useful to ensure
+ that the entire training run doesn't crash because of a bad policy that causes unstable
+ simulation computations.
+ """
+ return tuple(MUJOCO_EXCEPTIONS)
+
+ @property
+ def base_env(self):
+ """
+ Grabs base simulation environment.
+ """
+ return self.env
+
+ def __repr__(self):
+ """
+ Pretty-print env description.
+ """
+ return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/envs/wrappers.py b/phantom/submodules/phantom-robomimic/robomimic/envs/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb46091ef33279ce9199f9d70e8add72818671a3
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/envs/wrappers.py
@@ -0,0 +1,222 @@
+"""
+A collection of useful environment wrappers.
+"""
+from copy import deepcopy
+import textwrap
+import numpy as np
+from collections import deque
+
+import robomimic.envs.env_base as EB
+
+
+class EnvWrapper(object):
+ """
+ Base class for all environment wrappers in robomimic.
+ """
+ def __init__(self, env):
+ """
+ Args:
+ env (EnvBase instance): The environment to wrap.
+ """
+ assert isinstance(env, EB.EnvBase) or isinstance(env, EnvWrapper)
+ self.env = env
+
+ @classmethod
+ def class_name(cls):
+ return cls.__name__
+
+ def _warn_double_wrap(self):
+ """
+ Utility function that checks if we're accidentally trying to double wrap an env
+ Raises:
+ Exception: [Double wrapping env]
+ """
+ env = self.env
+ while True:
+ if isinstance(env, EnvWrapper):
+ if env.class_name() == self.class_name():
+ raise Exception(
+ "Attempted to double wrap with Wrapper: {}".format(
+ self.__class__.__name__
+ )
+ )
+ env = env.env
+ else:
+ break
+
+ @property
+ def unwrapped(self):
+ """
+ Grabs unwrapped environment
+
+ Returns:
+ env (EnvBase instance): Unwrapped environment
+ """
+ if hasattr(self.env, "unwrapped"):
+ return self.env.unwrapped
+ else:
+ return self.env
+
+ def _to_string(self):
+ """
+ Subclasses should override this method to print out info about the
+ wrapper (such as arguments passed to it).
+ """
+ return ''
+
+ def __repr__(self):
+ """Pretty print environment."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 4
+ if self._to_string() != '':
+ msg += textwrap.indent("\n" + self._to_string(), indent)
+ msg += textwrap.indent("\nenv={}".format(self.env), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+ # this method is a fallback option on any methods the original env might support
+ def __getattr__(self, attr):
+ # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called
+ # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute)
+ orig_attr = getattr(self.env, attr)
+ if callable(orig_attr):
+
+ def hooked(*args, **kwargs):
+ result = orig_attr(*args, **kwargs)
+ # prevent wrapped_class from becoming unwrapped
+ if id(result) == id(self.env):
+ return self
+ return result
+
+ return hooked
+ else:
+ return orig_attr
+
+
+class FrameStackWrapper(EnvWrapper):
+ """
+ Wrapper for frame stacking observations during rollouts. The agent
+ receives a sequence of past observations instead of a single observation
+ when it calls @env.reset, @env.reset_to, or @env.step in the rollout loop.
+ """
+ def __init__(self, env, num_frames):
+ """
+ Args:
+ env (EnvBase instance): The environment to wrap.
+ num_frames (int): number of past observations (including current observation)
+ to stack together. Must be greater than 1 (otherwise this wrapper would
+ be a no-op).
+ """
+ assert num_frames > 1, "error: FrameStackWrapper must have num_frames > 1 but got num_frames of {}".format(num_frames)
+
+ super(FrameStackWrapper, self).__init__(env=env)
+ self.num_frames = num_frames
+
+ ### TODO: add action padding option + adding action to obs to include action history in obs ###
+
+ # keep track of last @num_frames observations for each obs key
+ self.obs_history = None
+
+ def _get_initial_obs_history(self, init_obs):
+ """
+ Helper method to get observation history from the initial observation, by
+ repeating it.
+
+ Returns:
+ obs_history (dict): a deque for each observation key, with an extra
+ leading dimension of 1 for each key (for easy concatenation later)
+ """
+ obs_history = {}
+ for k in init_obs:
+ obs_history[k] = deque(
+ [init_obs[k][None] for _ in range(self.num_frames)],
+ maxlen=self.num_frames,
+ )
+ return obs_history
+
+ def _get_stacked_obs_from_history(self):
+ """
+ Helper method to convert internal variable @self.obs_history to a
+ stacked observation where each key is a numpy array with leading dimension
+ @self.num_frames.
+ """
+ # concatenate all frames per key so we return a numpy array per key
+ return { k : np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history }
+
+ def cache_obs_history(self):
+ self.obs_history_cache = deepcopy(self.obs_history)
+
+ def uncache_obs_history(self):
+ self.obs_history = self.obs_history_cache
+ self.obs_history_cache = None
+
+ def reset(self):
+ """
+ Modify to return frame stacked observation which is @self.num_frames copies of
+ the initial observation.
+
+ Returns:
+ obs_stacked (dict): each observation key in original observation now has
+ leading shape @self.num_frames and consists of the previous @self.num_frames
+ observations
+ """
+ obs = self.env.reset()
+ self.timestep = 0 # always zero regardless of timestep type
+ self.update_obs(obs, reset=True)
+ self.obs_history = self._get_initial_obs_history(init_obs=obs)
+ return self._get_stacked_obs_from_history()
+
+ def reset_to(self, state):
+ """
+ Modify to return frame stacked observation which is @self.num_frames copies of
+ the initial observation.
+
+ Returns:
+ obs_stacked (dict): each observation key in original observation now has
+ leading shape @self.num_frames and consists of the previous @self.num_frames
+ observations
+ """
+ obs = self.env.reset_to(state)
+ self.timestep = 0 # always zero regardless of timestep type
+ self.update_obs(obs, reset=True)
+ self.obs_history = self._get_initial_obs_history(init_obs=obs)
+ return self._get_stacked_obs_from_history()
+
+ def step(self, action):
+ """
+ Modify to update the internal frame history and return frame stacked observation,
+ which will have leading dimension @self.num_frames for each key.
+
+ Args:
+ action (np.array): action to take
+
+ Returns:
+ obs_stacked (dict): each observation key in original observation now has
+ leading shape @self.num_frames and consists of the previous @self.num_frames
+ observations
+ reward (float): reward for this step
+ done (bool): whether the task is done
+ info (dict): extra information
+ """
+ obs, r, done, info = self.env.step(action)
+ self.update_obs(obs, action=action, reset=False)
+ # update frame history
+ for k in obs:
+ # make sure to have leading dim of 1 for easy concatenation
+ self.obs_history[k].append(obs[k][None])
+ obs_ret = self._get_stacked_obs_from_history()
+ return obs_ret, r, done, info
+
+ def update_obs(self, obs, action=None, reset=False):
+ obs["timesteps"] = np.array([self.timestep])
+
+ if reset:
+ obs["actions"] = np.zeros(self.env.action_dimension)
+ else:
+ self.timestep += 1
+ obs["actions"] = action[: self.env.action_dimension]
+
+ def _to_string(self):
+ """Info to pretty print."""
+ return "num_frames={}".format(self.num_frames)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bc.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bc.json
new file mode 100644
index 0000000000000000000000000000000000000000..82ad783fbf330fecf0d59f97e346dc797dbaba1f
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bc.json
@@ -0,0 +1,215 @@
+{
+ "algo_name": "bc",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../bc_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": false,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "policy": {
+ "optimizer_type": "adam",
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": [],
+ "scheduler_type": "multistep"
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "loss": {
+ "l2_weight": 1.0,
+ "l1_weight": 0.0,
+ "cos_weight": 0.0
+ },
+ "actor_layer_dims": [
+ 1024,
+ 1024
+ ],
+ "gaussian": {
+ "enabled": false,
+ "fixed_std": false,
+ "init_std": 0.1,
+ "min_std": 0.01,
+ "std_activation": "softplus",
+ "low_noise_eval": true
+ },
+ "gmm": {
+ "enabled": false,
+ "num_modes": 5,
+ "min_std": 0.0001,
+ "std_activation": "softplus",
+ "low_noise_eval": true
+ },
+ "vae": {
+ "enabled": false,
+ "latent_dim": 14,
+ "latent_clip": null,
+ "kl_weight": 1.0,
+ "decoder": {
+ "is_conditioned": true,
+ "reconstruction_sum_across_elements": false
+ },
+ "prior": {
+ "learn": false,
+ "is_conditioned": false,
+ "use_gmm": false,
+ "gmm_num_modes": 10,
+ "gmm_learn_weights": false,
+ "use_categorical": false,
+ "categorical_dim": 10,
+ "categorical_gumbel_softmax_hard": false,
+ "categorical_init_temp": 1.0,
+ "categorical_temp_anneal_step": 0.001,
+ "categorical_min_temp": 0.3
+ },
+ "encoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "decoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "prior_layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "rnn": {
+ "enabled": false,
+ "horizon": 10,
+ "hidden_dim": 400,
+ "rnn_type": "LSTM",
+ "num_layers": 2,
+ "open_loop": false,
+ "kwargs": {
+ "bidirectional": false
+ }
+ },
+ "transformer": {
+ "enabled": false,
+ "context_length": 10,
+ "embed_dim": 512,
+ "num_layers": 6,
+ "num_heads": 8,
+ "emb_dropout": 0.1,
+ "attn_dropout": 0.1,
+ "block_output_dropout": 0.1,
+ "sinusoidal_embedding": false,
+ "activation": "gelu",
+ "supervise_all_steps": false,
+ "nn_parameter_for_timesteps": true
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bc_transformer.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bc_transformer.json
new file mode 100644
index 0000000000000000000000000000000000000000..c28696cb0d6abc2d081570ed4dc2eaf16939a819
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bc_transformer.json
@@ -0,0 +1,171 @@
+{
+ "algo_name": "bc",
+ "experiment": {
+ "name": "test",
+ "validate": true,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../bc_transformer_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "low_dim",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": false,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 10,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "policy": {
+ "optimizer_type": "adamw",
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": [100],
+ "scheduler_type": "linear"
+ },
+ "regularization": {
+ "L2": 0.01
+ }
+ }
+ },
+ "loss": {
+ "l2_weight": 1.0,
+ "l1_weight": 0.0,
+ "cos_weight": 0.0
+ },
+ "actor_layer_dims": [],
+ "gaussian": {
+ "enabled": false
+ },
+ "gmm": {
+ "enabled": true,
+ "num_modes": 5,
+ "min_std": 0.0001,
+ "std_activation": "softplus",
+ "low_noise_eval": true
+ },
+ "vae": {
+ "enabled": false
+ },
+ "rnn": {
+ "enabled": false
+ },
+ "transformer": {
+ "enabled": true,
+ "supervise_all_steps": false,
+ "num_layers": 6,
+ "embed_dim": 512,
+ "num_heads": 8
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {
+ "feature_dimension": 64,
+ "backbone_class": "ResNet18Conv",
+ "backbone_kwargs": {
+ "pretrained": false,
+ "input_coord_conv": false
+ },
+ "pool_class": "SpatialSoftmax",
+ "pool_kwargs": {
+ "num_kp": 32,
+ "learnable_temperature": false,
+ "temperature": 1.0,
+ "noise_std": 0.0
+ }
+ },
+ "obs_randomizer_class": "CropRandomizer",
+ "obs_randomizer_kwargs": {
+ "crop_height": 76,
+ "crop_width": 76,
+ "num_crops": 1,
+ "pos_enc": false
+ }
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bcq.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bcq.json
new file mode 100644
index 0000000000000000000000000000000000000000..5ae9d907466f4278b418bcc1fb93aacb7fcb1e2a
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/bcq.json
@@ -0,0 +1,235 @@
+{
+ "algo_name": "bcq",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../bcq_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "critic": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ },
+ "action_sampler": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ },
+ "actor": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ }
+ },
+ "discount": 0.99,
+ "n_step": 1,
+ "target_tau": 0.005,
+ "infinite_horizon": false,
+ "critic": {
+ "use_huber": false,
+ "max_gradient_norm": null,
+ "value_bounds": null,
+ "num_action_samples": 10,
+ "num_action_samples_rollout": 100,
+ "ensemble": {
+ "n": 2,
+ "weight": 0.75
+ },
+ "distributional": {
+ "enabled": false,
+ "num_atoms": 51
+ },
+ "layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "action_sampler": {
+ "actor_layer_dims": [
+ 1024,
+ 1024
+ ],
+ "gmm": {
+ "enabled": false,
+ "num_modes": 5,
+ "min_std": 0.0001,
+ "std_activation": "softplus",
+ "low_noise_eval": true
+ },
+ "vae": {
+ "enabled": true,
+ "latent_dim": 14,
+ "latent_clip": null,
+ "kl_weight": 1.0,
+ "decoder": {
+ "is_conditioned": true,
+ "reconstruction_sum_across_elements": false
+ },
+ "prior": {
+ "learn": false,
+ "is_conditioned": false,
+ "use_gmm": false,
+ "gmm_num_modes": 10,
+ "gmm_learn_weights": false,
+ "use_categorical": false,
+ "categorical_dim": 10,
+ "categorical_gumbel_softmax_hard": false,
+ "categorical_init_temp": 1.0,
+ "categorical_temp_anneal_step": 0.001,
+ "categorical_min_temp": 0.3
+ },
+ "encoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "decoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "prior_layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "freeze_encoder_epoch": -1
+ },
+ "actor": {
+ "enabled": false,
+ "perturbation_scale": 0.05,
+ "layer_dims": [
+ 300,
+ 400
+ ]
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/cql.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/cql.json
new file mode 100644
index 0000000000000000000000000000000000000000..a920efd6f01844971fba4881d73b762f7cf47ade
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/cql.json
@@ -0,0 +1,182 @@
+{
+ "algo_name": "cql",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../cql_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 1024,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "critic": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.0,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ },
+ "actor": {
+ "learning_rate": {
+ "initial": 0.0003,
+ "decay_factor": 0.0,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "discount": 0.99,
+ "n_step": 1,
+ "target_tau": 0.005,
+ "actor": {
+ "bc_start_steps": 0,
+ "target_entropy": "default",
+ "max_gradient_norm": null,
+ "net": {
+ "type": "gaussian",
+ "common": {
+ "std_activation": "exp",
+ "use_tanh": true,
+ "low_noise_eval": true
+ },
+ "gaussian": {
+ "init_last_fc_weight": 0.001,
+ "init_std": 0.3,
+ "fixed_std": false
+ }
+ },
+ "layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "critic": {
+ "use_huber": false,
+ "max_gradient_norm": null,
+ "value_bounds": null,
+ "num_action_samples": 1,
+ "cql_weight": 1.0,
+ "deterministic_backup": true,
+ "min_q_weight": 1.0,
+ "target_q_gap": 5.0,
+ "num_random_actions": 10,
+ "ensemble": {
+ "n": 2
+ },
+ "layer_dims": [
+ 300,
+ 400
+ ]
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/diffusion_policy.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/diffusion_policy.json
new file mode 100644
index 0000000000000000000000000000000000000000..75936bb53d5155bac7730c741b20aec7d554ac73
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/diffusion_policy.json
@@ -0,0 +1,174 @@
+{
+ "algo_name": "diffusion_policy",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir":"../diffusion_policy_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "low_dim",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": false,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "seq_length": 15,
+ "pad_seq_length": true,
+ "frame_stack": 2,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 256,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "policy": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "horizon": {
+ "observation_horizon": 2,
+ "action_horizon": 8,
+ "prediction_horizon": 16
+ },
+ "unet": {
+ "enabled": true,
+ "diffusion_step_embed_dim": 256,
+ "down_dims": [256,512,1024],
+ "kernel_size": 5,
+ "n_groups": 8
+ },
+ "ema": {
+ "enabled": true,
+ "power": 0.75
+ },
+ "ddpm": {
+ "enabled": true,
+ "num_train_timesteps": 100,
+ "num_inference_timesteps": 100,
+ "beta_schedule": "squaredcos_cap_v2",
+ "clip_sample": true,
+ "prediction_type": "epsilon"
+ },
+ "ddim": {
+ "enabled": false,
+ "num_train_timesteps": 100,
+ "num_inference_timesteps": 10,
+ "beta_schedule": "squaredcos_cap_v2",
+ "clip_sample": true,
+ "set_alpha_to_one": true,
+ "steps_offset": 0,
+ "prediction_type": "epsilon"
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {
+ "feature_dimension": 64,
+ "backbone_class": "ResNet18Conv",
+ "backbone_kwargs": {
+ "pretrained": false,
+ "input_coord_conv": false
+ },
+ "pool_class": "SpatialSoftmax",
+ "pool_kwargs": {
+ "num_kp": 32,
+ "learnable_temperature": false,
+ "temperature": 1.0,
+ "noise_std": 0.0
+ }
+ },
+ "obs_randomizer_class": "CropRandomizer",
+ "obs_randomizer_kwargs": {
+ "crop_height": 76,
+ "crop_width": 76,
+ "num_crops": 1,
+ "pos_enc": false
+ }
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/gl.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/gl.json
new file mode 100644
index 0000000000000000000000000000000000000000..39b4c2dbd65dad06afaaa1f88bd605a3477e3312
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/gl.json
@@ -0,0 +1,182 @@
+{
+ "algo_name": "gl",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../gl_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "goal_network": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "subgoal_horizon": 10,
+ "ae": {
+ "planner_layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "vae": {
+ "enabled": true,
+ "latent_dim": 16,
+ "latent_clip": null,
+ "kl_weight": 1.0,
+ "decoder": {
+ "is_conditioned": true,
+ "reconstruction_sum_across_elements": false
+ },
+ "prior": {
+ "learn": false,
+ "is_conditioned": false,
+ "use_gmm": false,
+ "gmm_num_modes": 10,
+ "gmm_learn_weights": false,
+ "use_categorical": false,
+ "categorical_dim": 10,
+ "categorical_gumbel_softmax_hard": false,
+ "categorical_init_temp": 1.0,
+ "categorical_temp_anneal_step": 0.001,
+ "categorical_min_temp": 0.3
+ },
+ "encoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "decoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "prior_layer_dims": [
+ 300,
+ 400
+ ]
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "subgoal": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/hbc.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/hbc.json
new file mode 100644
index 0000000000000000000000000000000000000000..26eff76a8f40e3fd787c7a561a91155369101b7e
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/hbc.json
@@ -0,0 +1,293 @@
+{
+ "algo_name": "hbc",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../hbc_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 10,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "mode": "separate",
+ "actor_use_random_subgoals": false,
+ "subgoal_update_interval": 10,
+ "latent_subgoal": {
+ "enabled": false,
+ "prior_correction": {
+ "enabled": false,
+ "num_samples": 100
+ }
+ },
+ "planner": {
+ "optim_params": {
+ "goal_network": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "subgoal_horizon": 10,
+ "ae": {
+ "planner_layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "vae": {
+ "enabled": true,
+ "latent_dim": 16,
+ "latent_clip": null,
+ "kl_weight": 1.0,
+ "decoder": {
+ "is_conditioned": true,
+ "reconstruction_sum_across_elements": false
+ },
+ "prior": {
+ "learn": false,
+ "is_conditioned": false,
+ "use_gmm": false,
+ "gmm_num_modes": 10,
+ "gmm_learn_weights": false,
+ "use_categorical": false,
+ "categorical_dim": 10,
+ "categorical_gumbel_softmax_hard": false,
+ "categorical_init_temp": 1.0,
+ "categorical_temp_anneal_step": 0.001,
+ "categorical_min_temp": 0.3
+ },
+ "encoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "decoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "prior_layer_dims": [
+ 300,
+ 400
+ ]
+ }
+ },
+ "actor": {
+ "optim_params": {
+ "policy": {
+ "optimizer_type": "adam",
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": [],
+ "scheduler_type": "multistep"
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "loss": {
+ "l2_weight": 1.0,
+ "l1_weight": 0.0,
+ "cos_weight": 0.0
+ },
+ "actor_layer_dims": [
+ 1024,
+ 1024
+ ],
+ "rnn": {
+ "enabled": true,
+ "horizon": 10,
+ "hidden_dim": 400,
+ "rnn_type": "LSTM",
+ "num_layers": 2,
+ "open_loop": false,
+ "kwargs": {
+ "bidirectional": false
+ }
+ },
+ "transformer": {
+ "enabled": false,
+ "context_length": 10,
+ "embed_dim": 512,
+ "num_layers": 6,
+ "num_heads": 8,
+ "emb_dropout": 0.1,
+ "attn_dropout": 0.1,
+ "block_output_dropout": 0.1,
+ "sinusoidal_embedding": false,
+ "activation": "gelu",
+ "supervise_all_steps": false,
+ "nn_parameter_for_timesteps": true
+ }
+ }
+ },
+ "observation": {
+ "planner": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "subgoal": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "actor": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/iql.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/iql.json
new file mode 100644
index 0000000000000000000000000000000000000000..4731788417924c649f1b92627fe6bf7f14668aac
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/iql.json
@@ -0,0 +1,192 @@
+{
+ "algo_name": "iql",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../iql_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "critic": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.0,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ },
+ "vf": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.0,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ },
+ "actor": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.0,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "discount": 0.99,
+ "target_tau": 0.01,
+ "actor": {
+ "net": {
+ "type": "gaussian",
+ "common": {
+ "std_activation": "softplus",
+ "low_noise_eval": true,
+ "use_tanh": false
+ },
+ "gaussian": {
+ "init_last_fc_weight": 0.001,
+ "init_std": 0.3,
+ "fixed_std": false
+ },
+ "gmm": {
+ "num_modes": 5,
+ "min_std": 0.0001
+ }
+ },
+ "layer_dims": [
+ 300,
+ 400
+ ],
+ "max_gradient_norm": null
+ },
+ "critic": {
+ "ensemble": {
+ "n": 2
+ },
+ "layer_dims": [
+ 300,
+ 400
+ ],
+ "use_huber": false,
+ "max_gradient_norm": null
+ },
+ "adv": {
+ "clip_adv_value": null,
+ "beta": 1.0,
+ "use_final_clip": true
+ },
+ "vf_quantile": 0.9
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/iris.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/iris.json
new file mode 100644
index 0000000000000000000000000000000000000000..6551663864a4d57d05d263de0069269ab115d8de
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/iris.json
@@ -0,0 +1,465 @@
+{
+ "algo_name": "iris",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 50,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": false,
+ "on_best_rollout_success_rate": true
+ },
+ "epoch_every_n_steps": 100,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": true,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 400,
+ "rate": 50,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../iris_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": false,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 10,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 100,
+ "num_epochs": 2000,
+ "seed": 1
+ },
+ "algo": {
+ "mode": "separate",
+ "actor_use_random_subgoals": false,
+ "subgoal_update_interval": 10,
+ "latent_subgoal": {
+ "enabled": false,
+ "prior_correction": {
+ "enabled": false,
+ "num_samples": 100
+ }
+ },
+ "value_planner": {
+ "planner": {
+ "optim_params": {
+ "goal_network": {
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "subgoal_horizon": 10,
+ "ae": {
+ "planner_layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "vae": {
+ "enabled": true,
+ "latent_dim": 16,
+ "latent_clip": null,
+ "kl_weight": 1.0,
+ "decoder": {
+ "is_conditioned": true,
+ "reconstruction_sum_across_elements": false
+ },
+ "prior": {
+ "learn": false,
+ "is_conditioned": false,
+ "use_gmm": false,
+ "gmm_num_modes": 10,
+ "gmm_learn_weights": false,
+ "use_categorical": false,
+ "categorical_dim": 10,
+ "categorical_gumbel_softmax_hard": false,
+ "categorical_init_temp": 1.0,
+ "categorical_temp_anneal_step": 0.001,
+ "categorical_min_temp": 0.3
+ },
+ "encoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "decoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "prior_layer_dims": [
+ 300,
+ 400
+ ]
+ }
+ },
+ "value": {
+ "optim_params": {
+ "critic": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ },
+ "action_sampler": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ },
+ "actor": {
+ "learning_rate": {
+ "initial": 0.001,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ }
+ },
+ "discount": 0.99,
+ "n_step": 1,
+ "target_tau": 0.005,
+ "infinite_horizon": false,
+ "critic": {
+ "use_huber": false,
+ "max_gradient_norm": null,
+ "value_bounds": null,
+ "num_action_samples": 10,
+ "num_action_samples_rollout": 100,
+ "ensemble": {
+ "n": 2,
+ "weight": 0.75
+ },
+ "distributional": {
+ "enabled": false,
+ "num_atoms": 51
+ },
+ "layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "action_sampler": {
+ "actor_layer_dims": [
+ 1024,
+ 1024
+ ],
+ "gmm": {
+ "enabled": false,
+ "num_modes": 5,
+ "min_std": 0.0001,
+ "std_activation": "softplus",
+ "low_noise_eval": true
+ },
+ "vae": {
+ "enabled": true,
+ "latent_dim": 14,
+ "latent_clip": null,
+ "kl_weight": 1.0,
+ "decoder": {
+ "is_conditioned": true,
+ "reconstruction_sum_across_elements": false
+ },
+ "prior": {
+ "learn": false,
+ "is_conditioned": false,
+ "use_gmm": false,
+ "gmm_num_modes": 10,
+ "gmm_learn_weights": false,
+ "use_categorical": false,
+ "categorical_dim": 10,
+ "categorical_gumbel_softmax_hard": false,
+ "categorical_init_temp": 1.0,
+ "categorical_temp_anneal_step": 0.001,
+ "categorical_min_temp": 0.3
+ },
+ "encoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "decoder_layer_dims": [
+ 300,
+ 400
+ ],
+ "prior_layer_dims": [
+ 300,
+ 400
+ ]
+ },
+ "freeze_encoder_epoch": -1
+ },
+ "actor": {
+ "enabled": false,
+ "perturbation_scale": 0.05,
+ "layer_dims": [
+ 300,
+ 400
+ ]
+ }
+ },
+ "num_samples": 100
+ },
+ "actor": {
+ "optim_params": {
+ "policy": {
+ "optimizer_type": "adam",
+ "learning_rate": {
+ "initial": 0.0001,
+ "decay_factor": 0.1,
+ "epoch_schedule": [],
+ "scheduler_type": "multistep"
+ },
+ "regularization": {
+ "L2": 0.0
+ }
+ }
+ },
+ "loss": {
+ "l2_weight": 1.0,
+ "l1_weight": 0.0,
+ "cos_weight": 0.0
+ },
+ "actor_layer_dims": [
+ 1024,
+ 1024
+ ],
+ "rnn": {
+ "enabled": true,
+ "horizon": 10,
+ "hidden_dim": 400,
+ "rnn_type": "LSTM",
+ "num_layers": 2,
+ "open_loop": false,
+ "kwargs": {
+ "bidirectional": false
+ }
+ },
+ "transformer": {
+ "enabled": false,
+ "context_length": 10,
+ "embed_dim": 512,
+ "num_layers": 6,
+ "num_heads": 8,
+ "emb_dropout": 0.1,
+ "attn_dropout": 0.1,
+ "block_output_dropout": 0.1,
+ "sinusoidal_embedding": false,
+ "activation": "gelu",
+ "supervise_all_steps": false,
+ "nn_parameter_for_timesteps": true
+ }
+ }
+ },
+ "observation": {
+ "value_planner": {
+ "planner": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "subgoal": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "value": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ }
+ },
+ "actor": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/exps/templates/td3_bc.json b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/td3_bc.json
new file mode 100644
index 0000000000000000000000000000000000000000..414a8f04f0cce7c9857207b1b1269ff10c3ee38b
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/exps/templates/td3_bc.json
@@ -0,0 +1,167 @@
+{
+ "algo_name": "td3_bc",
+ "experiment": {
+ "name": "test",
+ "validate": false,
+ "logging": {
+ "terminal_output_to_txt": true,
+ "log_tb": true,
+ "log_wandb": false,
+ "wandb_proj_name": "debug"
+ },
+ "save": {
+ "enabled": true,
+ "every_n_seconds": null,
+ "every_n_epochs": 20,
+ "epochs": [],
+ "on_best_validation": false,
+ "on_best_rollout_return": true,
+ "on_best_rollout_success_rate": false
+ },
+ "epoch_every_n_steps": 5000,
+ "validation_epoch_every_n_steps": 10,
+ "env": null,
+ "additional_envs": null,
+ "render": false,
+ "render_video": false,
+ "keep_all_videos": false,
+ "video_skip": 5,
+ "rollout": {
+ "enabled": true,
+ "n": 50,
+ "horizon": 1000,
+ "rate": 1,
+ "warmstart": 0,
+ "terminate_on_success": true
+ }
+ },
+ "train": {
+ "data": null,
+ "output_dir": "../td3_bc_trained_models",
+ "num_data_workers": 0,
+ "hdf5_cache_mode": "all",
+ "hdf5_use_swmr": true,
+ "hdf5_load_next_obs": true,
+ "hdf5_normalize_obs": true,
+ "hdf5_filter_key": null,
+ "hdf5_validation_filter_key": null,
+ "seq_length": 1,
+ "pad_seq_length": true,
+ "frame_stack": 1,
+ "pad_frame_stack": true,
+ "dataset_keys": [
+ "actions",
+ "rewards",
+ "dones"
+ ],
+ "goal_mode": null,
+ "cuda": true,
+ "batch_size": 256,
+ "num_epochs": 200,
+ "seed": 1
+ },
+ "algo": {
+ "optim_params": {
+ "critic": {
+ "learning_rate": {
+ "initial": 0.0003,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ },
+ "actor": {
+ "learning_rate": {
+ "initial": 0.0003,
+ "decay_factor": 0.1,
+ "epoch_schedule": []
+ },
+ "regularization": {
+ "L2": 0.0
+ },
+ "start_epoch": -1,
+ "end_epoch": -1
+ }
+ },
+ "alpha": 2.5,
+ "discount": 0.99,
+ "n_step": 1,
+ "target_tau": 0.005,
+ "infinite_horizon": false,
+ "critic": {
+ "use_huber": false,
+ "max_gradient_norm": null,
+ "value_bounds": null,
+ "ensemble": {
+ "n": 2,
+ "weight": 1.0
+ },
+ "layer_dims": [
+ 256,
+ 256
+ ]
+ },
+ "actor": {
+ "update_freq": 2,
+ "noise_std": 0.2,
+ "noise_clip": 0.5,
+ "layer_dims": [
+ 256,
+ 256
+ ]
+ }
+ },
+ "observation": {
+ "modalities": {
+ "obs": {
+ "low_dim": [
+ "flat"
+ ],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ },
+ "goal": {
+ "low_dim": [],
+ "rgb": [],
+ "depth": [],
+ "scan": []
+ }
+ },
+ "encoder": {
+ "low_dim": {
+ "core_class": null,
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "rgb": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "depth": {
+ "core_class": "VisualCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ },
+ "scan": {
+ "core_class": "ScanCore",
+ "core_kwargs": {},
+ "obs_randomizer_class": null,
+ "obs_randomizer_kwargs": {}
+ }
+ }
+ },
+ "meta": {
+ "hp_base_config_file": null,
+ "hp_keys": [],
+ "hp_values": []
+ }
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/macros.py b/phantom/submodules/phantom-robomimic/robomimic/macros.py
new file mode 100644
index 0000000000000000000000000000000000000000..7496e93bbe5277c68573bdea7543c4a187ec490c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/macros.py
@@ -0,0 +1,56 @@
+"""
+Set of global variables shared across robomimic
+"""
+# Sets debugging mode. Should be set at top-level script so that internal
+# debugging functionalities are made active
+DEBUG = False
+
+# Whether to visualize the before & after of an observation randomizer
+VISUALIZE_RANDOMIZER = False
+
+# wandb entity (eg. username or team name)
+WANDB_ENTITY = None
+
+# wandb api key (obtain from https://wandb.ai/authorize)
+# alternatively, set up wandb from terminal with `wandb login`
+WANDB_API_KEY = None
+
+### Slack Notifications ###
+
+# Token for sending slack notifications
+SLACK_TOKEN = None
+
+# User ID for user that should receive slack notifications
+SLACK_USER_ID = None
+
+
+### Local Sync Settings ###
+
+# By specifying this path, you can sync the most important results of training back to this folder
+RESULTS_SYNC_PATH = None
+
+# This will be automatically populated.
+RESULTS_SYNC_PATH_ABS = None
+
+
+### MagLev and NGC Cluster Settings ###
+
+# Whether training is happening on MagLev / NGC (should set this on repos hosted in MagLev / NGC scratch space or in Docker)
+USE_MAGLEV = False
+USE_NGC = False
+
+# When using MagLev / NGC, sync the most important results of training back to this directory in scratch space.
+# This path should be relative to the base scratch space directory (for MagLev) or an absolute path (for NGC)
+MAGLEV_SCRATCH_SYNC_PATH = None
+NGC_SCRATCH_SYNC_PATH = None
+
+try:
+ from robomimic.macros_private import *
+except ImportError:
+ from robomimic.utils.log_utils import log_warning
+ import robomimic
+ log_warning(
+ "No private macro file found!"\
+ "\nIt is recommended to use a private macro file"\
+ "\nTo setup, run: python {}/scripts/setup_macros.py".format(robomimic.__path__[0])
+ )
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/__init__.py b/phantom/submodules/phantom-robomimic/robomimic/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7460f9309af64c4578b547e0944c7e1366b5946c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/__init__.py
@@ -0,0 +1 @@
+from .obs_core import EncoderCore, Randomizer
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/base_nets.py b/phantom/submodules/phantom-robomimic/robomimic/models/base_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..18302a2c97a5278777adf2e626e8236d654143b2
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/base_nets.py
@@ -0,0 +1,1117 @@
+"""
+Contains torch Modules that correspond to basic network building blocks, like
+MLP, RNN, and CNN backbones.
+"""
+
+import math
+import abc
+import numpy as np
+import textwrap
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision import models as vision_models
+from torchvision import transforms
+
+import robomimic.utils.tensor_utils as TensorUtils
+
+
+CONV_ACTIVATIONS = {
+ "relu": nn.ReLU,
+ "None": None,
+ None: None,
+}
+
+
+def rnn_args_from_config(rnn_config):
+ """
+ Takes a Config object corresponding to RNN settings
+ (for example `config.algo.rnn` in BCConfig) and extracts
+ rnn kwargs for instantiating rnn networks.
+ """
+ return dict(
+ rnn_hidden_dim=rnn_config.hidden_dim,
+ rnn_num_layers=rnn_config.num_layers,
+ rnn_type=rnn_config.rnn_type,
+ rnn_kwargs=dict(rnn_config.kwargs),
+ )
+
+
+def transformer_args_from_config(transformer_config):
+ """
+ Takes a Config object corresponding to Transformer settings
+ (for example `config.algo.transformer` in BCConfig) and extracts
+ transformer kwargs for instantiating transformer networks.
+ """
+ transformer_args = dict(
+ transformer_context_length=transformer_config.context_length,
+ transformer_embed_dim=transformer_config.embed_dim,
+ transformer_num_heads=transformer_config.num_heads,
+ transformer_emb_dropout=transformer_config.emb_dropout,
+ transformer_attn_dropout=transformer_config.attn_dropout,
+ transformer_block_output_dropout=transformer_config.block_output_dropout,
+ transformer_sinusoidal_embedding=transformer_config.sinusoidal_embedding,
+ transformer_activation=transformer_config.activation,
+ transformer_nn_parameter_for_timesteps=transformer_config.nn_parameter_for_timesteps,
+ )
+
+ if "num_layers" in transformer_config:
+ transformer_args["transformer_num_layers"] = transformer_config.num_layers
+
+ return transformer_args
+
+
+class Module(torch.nn.Module):
+ """
+ Base class for networks. The only difference from torch.nn.Module is that it
+ requires implementing @output_shape.
+ """
+ @abc.abstractmethod
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ raise NotImplementedError
+
+
+class Sequential(torch.nn.Sequential, Module):
+ """
+ Compose multiple Modules together (defined above).
+ """
+ def __init__(self, *args, has_output_shape = True):
+ """
+ Args:
+ has_output_shape (bool, optional): indicates whether output_shape can be called on the Sequential module.
+ torch.nn modules do not have an output_shape, but Modules (defined above) do. Defaults to True.
+ """
+ for arg in args:
+ if has_output_shape:
+ assert isinstance(arg, Module)
+ else:
+ assert isinstance(arg, nn.Module)
+ torch.nn.Sequential.__init__(self, *args)
+ self.fixed = False
+ self.has_output_shape = has_output_shape
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ if not self.has_output_shape:
+ raise NotImplementedError("Output shape is not defined for this module")
+ out_shape = input_shape
+ for module in self:
+ out_shape = module.output_shape(out_shape)
+ return out_shape
+
+ def freeze(self):
+ self.fixed = True
+
+ def train(self, mode):
+ if self.fixed:
+ super().train(False)
+ else:
+ super().train(mode)
+
+
+class Parameter(Module):
+ """
+ A class that is a thin wrapper around a torch.nn.Parameter to make for easy saving
+ and optimization.
+ """
+ def __init__(self, init_tensor):
+ """
+ Args:
+ init_tensor (torch.Tensor): initial tensor
+ """
+ super(Parameter, self).__init__()
+ self.param = torch.nn.Parameter(init_tensor)
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ return list(self.param.shape)
+
+ def forward(self, inputs=None):
+ """
+ Forward call just returns the parameter tensor.
+ """
+ return self.param
+
+
+class Unsqueeze(Module):
+ """
+ Trivial class that unsqueezes the input. Useful for including in a nn.Sequential network
+ """
+ def __init__(self, dim):
+ super(Unsqueeze, self).__init__()
+ self.dim = dim
+
+ def output_shape(self, input_shape=None):
+ assert input_shape is not None
+ return input_shape + [1] if self.dim == -1 else input_shape[:self.dim + 1] + [1] + input_shape[self.dim + 1:]
+
+ def forward(self, x):
+ return x.unsqueeze(dim=self.dim)
+
+
+class Squeeze(Module):
+ """
+ Trivial class that squeezes the input. Useful for including in a nn.Sequential network
+ """
+
+ def __init__(self, dim):
+ super(Squeeze, self).__init__()
+ self.dim = dim
+
+ def output_shape(self, input_shape=None):
+ assert input_shape is not None
+ return input_shape[:self.dim] + input_shape[self.dim+1:] if input_shape[self.dim] == 1 else input_shape
+
+ def forward(self, x):
+ return x.squeeze(dim=self.dim)
+
+
+class MLP(Module):
+ """
+ Base class for simple Multi-Layer Perceptrons.
+ """
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ layer_dims=(),
+ layer_func=nn.Linear,
+ layer_func_kwargs=None,
+ activation=nn.ReLU,
+ dropouts=None,
+ normalization=False,
+ output_activation=None,
+ ):
+ """
+ Args:
+ input_dim (int): dimension of inputs
+
+ output_dim (int): dimension of outputs
+
+ layer_dims ([int]): sequence of integers for the hidden layers sizes
+
+ layer_func: mapping per layer - defaults to Linear
+
+ layer_func_kwargs (dict): kwargs for @layer_func
+
+ activation: non-linearity per layer - defaults to ReLU
+
+ dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities
+ after every layer. Must be same size as @layer_dims.
+
+ normalization (bool): if True, apply layer normalization after each layer
+
+ output_activation: if provided, applies the provided non-linearity to the output layer
+ """
+ super(MLP, self).__init__()
+ layers = []
+ dim = input_dim
+ if layer_func_kwargs is None:
+ layer_func_kwargs = dict()
+ if dropouts is not None:
+ assert(len(dropouts) == len(layer_dims))
+ for i, l in enumerate(layer_dims):
+ layers.append(layer_func(dim, l, **layer_func_kwargs))
+ if normalization:
+ layers.append(nn.LayerNorm(l))
+ layers.append(activation())
+ if dropouts is not None and dropouts[i] > 0.:
+ layers.append(nn.Dropout(dropouts[i]))
+ dim = l
+ layers.append(layer_func(dim, output_dim))
+ if output_activation is not None:
+ layers.append(output_activation())
+ self._layer_func = layer_func
+ self.nets = layers
+ self._model = nn.Sequential(*layers)
+
+ self._layer_dims = layer_dims
+ self._input_dim = input_dim
+ self._output_dim = output_dim
+ self._dropouts = dropouts
+ self._act = activation
+ self._output_act = output_activation
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ return [self._output_dim]
+
+ def forward(self, inputs):
+ """
+ Forward pass.
+ """
+ return self._model(inputs)
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = str(self.__class__.__name__)
+ act = None if self._act is None else self._act.__name__
+ output_act = None if self._output_act is None else self._output_act.__name__
+
+ indent = ' ' * 4
+ msg = "input_dim={}\noutput_dim={}\nlayer_dims={}\nlayer_func={}\ndropout={}\nact={}\noutput_act={}".format(
+ self._input_dim, self._output_dim, self._layer_dims,
+ self._layer_func.__name__, self._dropouts, act, output_act
+ )
+ msg = textwrap.indent(msg, indent)
+ msg = header + '(\n' + msg + '\n)'
+ return msg
+
+
+class RNN_Base(Module):
+ """
+ A wrapper class for a multi-step RNN and a per-step network.
+ """
+ def __init__(
+ self,
+ input_dim,
+ rnn_hidden_dim,
+ rnn_num_layers,
+ rnn_type="LSTM", # [LSTM, GRU]
+ rnn_kwargs=None,
+ per_step_net=None,
+ ):
+ """
+ Args:
+ input_dim (int): dimension of inputs
+
+ rnn_hidden_dim (int): RNN hidden dimension
+
+ rnn_num_layers (int): number of RNN layers
+
+ rnn_type (str): [LSTM, GRU]
+
+ rnn_kwargs (dict): kwargs for the torch.nn.LSTM / GRU
+
+ per_step_net: a network that runs per time step on top of the RNN output
+ """
+ super(RNN_Base, self).__init__()
+ self.per_step_net = per_step_net
+ if per_step_net is not None:
+ assert isinstance(per_step_net, Module), "RNN_Base: per_step_net is not instance of Module"
+
+ assert rnn_type in ["LSTM", "GRU"]
+ rnn_cls = nn.LSTM if rnn_type == "LSTM" else nn.GRU
+ rnn_kwargs = rnn_kwargs if rnn_kwargs is not None else {}
+ rnn_is_bidirectional = rnn_kwargs.get("bidirectional", False)
+
+ self.nets = rnn_cls(
+ input_size=input_dim,
+ hidden_size=rnn_hidden_dim,
+ num_layers=rnn_num_layers,
+ batch_first=True,
+ **rnn_kwargs,
+ )
+
+ self._hidden_dim = rnn_hidden_dim
+ self._num_layers = rnn_num_layers
+ self._rnn_type = rnn_type
+ self._num_directions = int(rnn_is_bidirectional) + 1 # 2 if bidirectional, 1 otherwise
+
+ @property
+ def rnn_type(self):
+ return self._rnn_type
+
+ def get_rnn_init_state(self, batch_size, device):
+ """
+ Get a default RNN state (zeros)
+ Args:
+ batch_size (int): batch size dimension
+
+ device: device the hidden state should be sent to.
+
+ Returns:
+ hidden_state (torch.Tensor or tuple): returns hidden state tensor or tuple of hidden state tensors
+ depending on the RNN type
+ """
+ h_0 = torch.zeros(self._num_layers * self._num_directions, batch_size, self._hidden_dim).to(device)
+ if self._rnn_type == "LSTM":
+ c_0 = torch.zeros(self._num_layers * self._num_directions, batch_size, self._hidden_dim).to(device)
+ return h_0, c_0
+ else:
+ return h_0
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # infer time dimension from input shape and add to per_step_net output shape
+ if self.per_step_net is not None:
+ out = self.per_step_net.output_shape(input_shape[1:])
+ if isinstance(out, dict):
+ out = {k: [input_shape[0]] + out[k] for k in out}
+ else:
+ out = [input_shape[0]] + out
+ else:
+ out = [input_shape[0], self._num_layers * self._hidden_dim]
+ return out
+
+ def forward(self, inputs, rnn_init_state=None, return_state=False):
+ """
+ Forward a sequence of inputs through the RNN and the per-step network.
+
+ Args:
+ inputs (torch.Tensor): tensor input of shape [B, T, D], where D is the RNN input size
+
+ rnn_init_state: rnn hidden state, initialize to zero state if set to None
+
+ return_state (bool): whether to return hidden state
+
+ Returns:
+ outputs: outputs of the per_step_net
+
+ rnn_state: return rnn state at the end if return_state is set to True
+ """
+ assert inputs.ndimension() == 3 # [B, T, D]
+ batch_size, seq_length, inp_dim = inputs.shape
+ if rnn_init_state is None:
+ rnn_init_state = self.get_rnn_init_state(batch_size, device=inputs.device)
+
+ outputs, rnn_state = self.nets(inputs, rnn_init_state)
+ if self.per_step_net is not None:
+ outputs = TensorUtils.time_distributed(outputs, self.per_step_net)
+
+ if return_state:
+ return outputs, rnn_state
+ else:
+ return outputs
+
+ def forward_step(self, inputs, rnn_state):
+ """
+ Forward a single step input through the RNN and per-step network, and return the new hidden state.
+ Args:
+ inputs (torch.Tensor): tensor input of shape [B, D], where D is the RNN input size
+
+ rnn_state: rnn hidden state, initialize to zero state if set to None
+
+ Returns:
+ outputs: outputs of the per_step_net
+
+ rnn_state: return the new rnn state
+ """
+ assert inputs.ndimension() == 2
+ inputs = TensorUtils.to_sequence(inputs)
+ outputs, rnn_state = self.forward(
+ inputs,
+ rnn_init_state=rnn_state,
+ return_state=True,
+ )
+ return outputs[:, 0], rnn_state
+
+
+"""
+================================================
+Visual Backbone Networks
+================================================
+"""
+class ConvBase(Module):
+ """
+ Base class for ConvNets.
+ """
+ def __init__(self):
+ super(ConvBase, self).__init__()
+
+ # dirty hack - re-implement to pass the buck onto subclasses from ABC parent
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ raise NotImplementedError
+
+ def forward(self, inputs):
+ x = self.nets(inputs)
+ if list(self.output_shape(list(inputs.shape)[1:])) != list(x.shape)[1:]:
+ raise ValueError('Size mismatch: expect size %s, but got size %s' % (
+ str(self.output_shape(list(inputs.shape)[1:])), str(list(x.shape)[1:]))
+ )
+ return x
+
+
+class ResNet18Conv(ConvBase):
+ """
+ A ResNet18 block that can be used to process input images.
+ """
+ def __init__(
+ self,
+ input_channel=3,
+ pretrained=False,
+ input_coord_conv=False,
+ ):
+ """
+ Args:
+ input_channel (int): number of input channels for input images to the network.
+ If not equal to 3, modifies first conv layer in ResNet to handle the number
+ of input channels.
+ pretrained (bool): if True, load pretrained weights for all ResNet layers.
+ input_coord_conv (bool): if True, use a coordinate convolution for the first layer
+ (a convolution where input channels are modified to encode spatial pixel location)
+ """
+ super(ResNet18Conv, self).__init__()
+ net = vision_models.resnet18(pretrained=pretrained)
+
+ if input_coord_conv:
+ net.conv1 = CoordConv2d(input_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ elif input_channel != 3:
+ net.conv1 = nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
+
+ # cut the last fc layer
+ self._input_coord_conv = input_coord_conv
+ self._input_channel = input_channel
+ self.nets = torch.nn.Sequential(*(list(net.children())[:-2]))
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ assert(len(input_shape) == 3)
+ out_h = int(math.ceil(input_shape[1] / 32.))
+ out_w = int(math.ceil(input_shape[2] / 32.))
+ return [512, out_h, out_w]
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ return header + '(input_channel={}, input_coord_conv={})'.format(self._input_channel, self._input_coord_conv)
+
+
+class R3MConv(ConvBase):
+ """
+ Base class for ConvNets pretrained with R3M (https://arxiv.org/abs/2203.12601)
+ """
+ def __init__(
+ self,
+ input_channel=3,
+ r3m_model_class='resnet18',
+ freeze=True,
+ ):
+ """
+ Using R3M pretrained observation encoder network proposed by https://arxiv.org/abs/2203.12601
+ Args:
+ input_channel (int): number of input channels for input images to the network.
+ If not equal to 3, modifies first conv layer in ResNet to handle the number
+ of input channels.
+ r3m_model_class (str): select one of the r3m pretrained model "resnet18", "resnet34" or "resnet50"
+ freeze (bool): if True, use a frozen R3M pretrained model.
+ """
+ super(R3MConv, self).__init__()
+
+ try:
+ from r3m import load_r3m
+ except ImportError:
+ print("WARNING: could not load r3m library! Please follow https://github.com/facebookresearch/r3m to install R3M")
+
+ net = load_r3m(r3m_model_class)
+
+ assert input_channel == 3 # R3M only support input image with channel size 3
+ assert r3m_model_class in ["resnet18", "resnet34", "resnet50"] # make sure the selected r3m model do exist
+
+ # cut the last fc layer
+ self._input_channel = input_channel
+ self._r3m_model_class = r3m_model_class
+ self._freeze = freeze
+ self._input_coord_conv = False
+ self._pretrained = True
+
+ preprocess = nn.Sequential(
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ )
+ self.nets = Sequential(*([preprocess] + list(net.module.convnet.children())), has_output_shape = False)
+ if freeze:
+ self.nets.freeze()
+
+ self.weight_sum = np.sum([param.cpu().data.numpy().sum() for param in self.nets.parameters()])
+ if freeze:
+ for param in self.nets.parameters():
+ param.requires_grad = False
+
+ self.nets.eval()
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ assert(len(input_shape) == 3)
+
+ if self._r3m_model_class == 'resnet50':
+ out_dim = 2048
+ else:
+ out_dim = 512
+
+ return [out_dim, 1, 1]
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ return header + '(input_channel={}, input_coord_conv={}, pretrained={}, freeze={})'.format(self._input_channel, self._input_coord_conv, self._pretrained, self._freeze)
+
+
+class MVPConv(ConvBase):
+ """
+ Base class for ConvNets pretrained with MVP (https://arxiv.org/abs/2203.06173)
+ """
+ def __init__(
+ self,
+ input_channel=3,
+ mvp_model_class='vitb-mae-egosoup',
+ freeze=True,
+ ):
+ """
+ Using MVP pretrained observation encoder network proposed by https://arxiv.org/abs/2203.06173
+ Args:
+ input_channel (int): number of input channels for input images to the network.
+ If not equal to 3, modifies first conv layer in ResNet to handle the number
+ of input channels.
+ mvp_model_class (str): select one of the mvp pretrained model "vits-mae-hoi", "vits-mae-in", "vits-sup-in", "vitb-mae-egosoup" or "vitl-256-mae-egosoup"
+ freeze (bool): if True, use a frozen MVP pretrained model.
+ """
+ super(MVPConv, self).__init__()
+
+ try:
+ import mvp
+ except ImportError:
+ print("WARNING: could not load mvp library! Please follow https://github.com/ir413/mvp to install MVP.")
+
+ self.nets = mvp.load(mvp_model_class)
+ if freeze:
+ self.nets.freeze()
+
+ assert input_channel == 3 # MVP only support input image with channel size 3
+ assert mvp_model_class in ["vits-mae-hoi", "vits-mae-in", "vits-sup-in", "vitb-mae-egosoup", "vitl-256-mae-egosoup"] # make sure the selected r3m model do exist
+
+ self._input_channel = input_channel
+ self._freeze = freeze
+ self._mvp_model_class = mvp_model_class
+ self._input_coord_conv = False
+ self._pretrained = True
+
+ if '256' in mvp_model_class:
+ input_img_size = 256
+ else:
+ input_img_size = 224
+ self.preprocess = nn.Sequential(
+ transforms.Resize(input_img_size)
+ )
+
+ def forward(self, inputs):
+ x = self.preprocess(inputs)
+ x = self.nets(x)
+ if list(self.output_shape(list(inputs.shape)[1:])) != list(x.shape)[1:]:
+ raise ValueError('Size mismatch: expect size %s, but got size %s' % (
+ str(self.output_shape(list(inputs.shape)[1:])), str(list(x.shape)[1:]))
+ )
+ return x
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ assert(len(input_shape) == 3)
+ if 'vitb' in self._mvp_model_class:
+ output_shape = [768]
+ elif 'vitl' in self._mvp_model_class:
+ output_shape = [1024]
+ else:
+ output_shape = [384]
+ return output_shape
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ return header + '(input_channel={}, input_coord_conv={}, pretrained={}, freeze={})'.format(self._input_channel, self._input_coord_conv, self._pretrained, self._freeze)
+
+
+class CoordConv2d(nn.Conv2d, Module):
+ """
+ 2D Coordinate Convolution
+
+ Source: An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution
+ https://arxiv.org/abs/1807.03247
+ (e.g. adds 2 channels per input feature map corresponding to (x, y) location on map)
+ """
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ padding_mode='zeros',
+ coord_encoding='position',
+ ):
+ """
+ Args:
+ in_channels: number of channels of the input tensor [C, H, W]
+ out_channels: number of output channels of the layer
+ kernel_size: convolution kernel size
+ stride: conv stride
+ padding: conv padding
+ dilation: conv dilation
+ groups: conv groups
+ bias: conv bias
+ padding_mode: conv padding mode
+ coord_encoding: type of coordinate encoding. currently only 'position' is implemented
+ """
+
+ assert(coord_encoding in ['position'])
+ self.coord_encoding = coord_encoding
+ if coord_encoding == 'position':
+ in_channels += 2 # two extra channel for positional encoding
+ self._position_enc = None # position encoding
+ else:
+ raise Exception("CoordConv2d: coord encoding {} not implemented".format(self.coord_encoding))
+ nn.Conv2d.__init__(
+ self,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode=padding_mode
+ )
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # adds 2 to channel dimension
+ return [input_shape[0] + 2] + input_shape[1:]
+
+ def forward(self, input):
+ b, c, h, w = input.shape
+ if self.coord_encoding == 'position':
+ if self._position_enc is None:
+ pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+ pos_y = pos_y.float().to(input.device) / float(h)
+ pos_x = pos_x.float().to(input.device) / float(w)
+ self._position_enc = torch.stack((pos_y, pos_x)).unsqueeze(0)
+ pos_enc = self._position_enc.expand(b, -1, -1, -1)
+ input = torch.cat((input, pos_enc), dim=1)
+ return super(CoordConv2d, self).forward(input)
+
+
+class ShallowConv(ConvBase):
+ """
+ A shallow convolutional encoder from https://rll.berkeley.edu/dsae/dsae.pdf
+ """
+ def __init__(self, input_channel=3, output_channel=32):
+ super(ShallowConv, self).__init__()
+ self._input_channel = input_channel
+ self._output_channel = output_channel
+ self.nets = nn.Sequential(
+ torch.nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
+ )
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ assert(len(input_shape) == 3)
+ assert(input_shape[0] == self._input_channel)
+ out_h = int(math.floor(input_shape[1] / 2.))
+ out_w = int(math.floor(input_shape[2] / 2.))
+ return [self._output_channel, out_h, out_w]
+
+
+class Conv1dBase(Module):
+ """
+ Base class for stacked Conv1d layers.
+
+ Args:
+ input_channel (int): Number of channels for inputs to this network
+ activation (None or str): Per-layer activation to use. Defaults to "relu". Valid options are
+ currently {relu, None} for no activation
+ out_channels (list of int): Output channel size for each sequential Conv1d layer
+ kernel_size (list of int): Kernel sizes for each sequential Conv1d layer
+ stride (list of int): Stride sizes for each sequential Conv1d layer
+ conv_kwargs (dict): additional nn.Conv1D args to use, in list form, where the ith element corresponds to the
+ argument to be passed to the ith Conv1D layer.
+ See https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html for specific possible arguments.
+ """
+ def __init__(
+ self,
+ input_channel=1,
+ activation="relu",
+ out_channels=(32, 64, 64),
+ kernel_size=(8, 4, 2),
+ stride=(4, 2, 1),
+ **conv_kwargs,
+ ):
+ super(Conv1dBase, self).__init__()
+
+ # Get activation requested
+ activation = CONV_ACTIVATIONS[activation]
+
+ # Add layer kwargs
+ conv_kwargs["out_channels"] = out_channels
+ conv_kwargs["kernel_size"] = kernel_size
+ conv_kwargs["stride"] = stride
+
+ # Generate network
+ self.n_layers = len(out_channels)
+ layers = OrderedDict()
+ for i in range(self.n_layers):
+ layer_kwargs = {k: v[i] for k, v in conv_kwargs.items()}
+ layers[f'conv{i}'] = nn.Conv1d(
+ in_channels=input_channel,
+ **layer_kwargs,
+ )
+ if activation is not None:
+ layers[f'act{i}'] = activation()
+ input_channel = layer_kwargs["out_channels"]
+
+ # Store network
+ self.nets = nn.Sequential(layers)
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ channels, length = input_shape
+ for i in range(self.n_layers):
+ net = getattr(self.nets, f"conv{i}")
+ channels = net.out_channels
+ length = int((length + 2 * net.padding[0] - net.dilation[0] * (net.kernel_size[0] - 1) - 1) / net.stride[0]) + 1
+ return [channels, length]
+
+ def forward(self, inputs):
+ x = self.nets(inputs)
+ if list(self.output_shape(list(inputs.shape)[1:])) != list(x.shape)[1:]:
+ raise ValueError('Size mismatch: expect size %s, but got size %s' % (
+ str(self.output_shape(list(inputs.shape)[1:])), str(list(x.shape)[1:]))
+ )
+ return x
+
+
+"""
+================================================
+Pooling Networks
+================================================
+"""
+class SpatialSoftmax(ConvBase):
+ """
+ Spatial Softmax Layer.
+
+ Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
+ https://rll.berkeley.edu/dsae/dsae.pdf
+ """
+ def __init__(
+ self,
+ input_shape,
+ num_kp=32,
+ temperature=1.,
+ learnable_temperature=False,
+ output_variance=False,
+ noise_std=0.0,
+ ):
+ """
+ Args:
+ input_shape (list): shape of the input feature (C, H, W)
+ num_kp (int): number of keypoints (None for not using spatialsoftmax)
+ temperature (float): temperature term for the softmax.
+ learnable_temperature (bool): whether to learn the temperature
+ output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
+ noise_std (float): add random spatial noise to the predicted keypoints
+ """
+ super(SpatialSoftmax, self).__init__()
+ assert len(input_shape) == 3
+ self._in_c, self._in_h, self._in_w = input_shape # (C, H, W)
+
+ if num_kp is not None:
+ self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+ self._num_kp = num_kp
+ else:
+ self.nets = None
+ self._num_kp = self._in_c
+ self.learnable_temperature = learnable_temperature
+ self.output_variance = output_variance
+ self.noise_std = noise_std
+
+ if self.learnable_temperature:
+ # temperature will be learned
+ temperature = torch.nn.Parameter(torch.ones(1) * temperature, requires_grad=True)
+ self.register_parameter('temperature', temperature)
+ else:
+ # temperature held constant after initialization
+ temperature = torch.nn.Parameter(torch.ones(1) * temperature, requires_grad=False)
+ self.register_buffer('temperature', temperature)
+
+ pos_x, pos_y = np.meshgrid(
+ np.linspace(-1., 1., self._in_w),
+ np.linspace(-1., 1., self._in_h)
+ )
+ pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h * self._in_w)).float()
+ pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h * self._in_w)).float()
+ self.register_buffer('pos_x', pos_x)
+ self.register_buffer('pos_y', pos_y)
+
+ self.kps = None
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = format(str(self.__class__.__name__))
+ return header + '(num_kp={}, temperature={}, noise={})'.format(
+ self._num_kp, self.temperature.item(), self.noise_std)
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ assert(len(input_shape) == 3)
+ assert(input_shape[0] == self._in_c)
+ return [self._num_kp, 2]
+
+ def forward(self, feature):
+ """
+ Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
+ probability distribution is created using a softmax, where the support is the
+ pixel locations. This distribution is used to compute the expected value of
+ the pixel location, which becomes a keypoint of dimension 2. K such keypoints
+ are created.
+
+ Returns:
+ out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
+ keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
+ under the 2D spatial softmax distribution
+ """
+ assert(feature.shape[1] == self._in_c)
+ assert(feature.shape[2] == self._in_h)
+ assert(feature.shape[3] == self._in_w)
+ if self.nets is not None:
+ feature = self.nets(feature)
+
+ # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
+ feature = feature.reshape(-1, self._in_h * self._in_w)
+ # 2d softmax normalization
+ attention = F.softmax(feature / self.temperature, dim=-1)
+ # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
+ expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
+ expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
+ # stack to [B * K, 2]
+ expected_xy = torch.cat([expected_x, expected_y], 1)
+ # reshape to [B, K, 2]
+ feature_keypoints = expected_xy.view(-1, self._num_kp, 2)
+
+ if self.training:
+ noise = torch.randn_like(feature_keypoints) * self.noise_std
+ feature_keypoints += noise
+
+ if self.output_variance:
+ # treat attention as a distribution, and compute second-order statistics to return
+ expected_xx = torch.sum(self.pos_x * self.pos_x * attention, dim=1, keepdim=True)
+ expected_yy = torch.sum(self.pos_y * self.pos_y * attention, dim=1, keepdim=True)
+ expected_xy = torch.sum(self.pos_x * self.pos_y * attention, dim=1, keepdim=True)
+ var_x = expected_xx - expected_x * expected_x
+ var_y = expected_yy - expected_y * expected_y
+ var_xy = expected_xy - expected_x * expected_y
+ # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
+ feature_covar = torch.cat([var_x, var_xy, var_xy, var_y], 1).reshape(-1, self._num_kp, 2, 2)
+ feature_keypoints = (feature_keypoints, feature_covar)
+
+ if isinstance(feature_keypoints, tuple):
+ self.kps = (feature_keypoints[0].detach(), feature_keypoints[1].detach())
+ else:
+ self.kps = feature_keypoints.detach()
+ return feature_keypoints
+
+
+class SpatialMeanPool(Module):
+ """
+ Module that averages inputs across all spatial dimensions (dimension 2 and after),
+ leaving only the batch and channel dimensions.
+ """
+ def __init__(self, input_shape):
+ super(SpatialMeanPool, self).__init__()
+ assert len(input_shape) == 3 # [C, H, W]
+ self.in_shape = input_shape
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ return list(self.in_shape[:1]) # [C, H, W] -> [C]
+
+ def forward(self, inputs):
+ """Forward pass - average across all dimensions except batch and channel."""
+ return TensorUtils.flatten(inputs, begin_axis=2).mean(dim=2)
+
+
+class FeatureAggregator(Module):
+ """
+ Helpful class for aggregating features across a dimension. This is useful in
+ practice when training models that break an input image up into several patches
+ since features can be extraced per-patch using the same encoder and then
+ aggregated using this module.
+ """
+ def __init__(self, dim=1, agg_type="avg"):
+ super(FeatureAggregator, self).__init__()
+ self.dim = dim
+ self.agg_type = agg_type
+
+ def set_weight(self, w):
+ assert self.agg_type == "w_avg"
+ self.agg_weight = w
+
+ def clear_weight(self):
+ assert self.agg_type == "w_avg"
+ self.agg_weight = None
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ # aggregates on @self.dim, so it is removed from the output shape
+ return list(input_shape[:self.dim]) + list(input_shape[self.dim+1:])
+
+ def forward(self, x):
+ """Forward pooling pass."""
+ if self.agg_type == "avg":
+ # mean-pooling
+ return torch.mean(x, dim=1)
+ if self.agg_type == "w_avg":
+ # weighted mean-pooling
+ return torch.sum(x * self.agg_weight, dim=1)
+ raise Exception("unexpected agg type: {}".forward(self.agg_type))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/distributions.py b/phantom/submodules/phantom-robomimic/robomimic/models/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..411efb1a8bbc6b0da7ac6f628357dc9c178b8780
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/distributions.py
@@ -0,0 +1,123 @@
+"""
+Contains distribution models used as parts of other networks. These
+classes usually inherit or emulate torch distributions.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as D
+
+
+class TanhWrappedDistribution(D.Distribution):
+ """
+ Class that wraps another valid torch distribution, such that sampled values from the base distribution are
+ passed through a tanh layer. The corresponding (log) probabilities are also modified accordingly.
+ Tanh Normal distribution - adapted from rlkit and CQL codebase
+ (https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/distributions.py#L6).
+ """
+ def __init__(self, base_dist, scale=1.0, epsilon=1e-6):
+ """
+ Args:
+ base_dist (Distribution): Distribution to wrap with tanh output
+ scale (float): Scale of output
+ epsilon (float): Numerical stability epsilon when computing log-prob.
+ """
+ self.base_dist = base_dist
+ self.scale = scale
+ self.tanh_epsilon = epsilon
+ super(TanhWrappedDistribution, self).__init__()
+
+ def log_prob(self, value, pre_tanh_value=None):
+ """
+ Args:
+ value (torch.Tensor): some tensor to compute log probabilities for
+ pre_tanh_value: If specified, will not calculate atanh manually from @value. More numerically stable
+ """
+ value = value / self.scale
+ if pre_tanh_value is None:
+ one_plus_x = (1. + value).clamp(min=self.tanh_epsilon)
+ one_minus_x = (1. - value).clamp(min=self.tanh_epsilon)
+ pre_tanh_value = 0.5 * torch.log(one_plus_x / one_minus_x)
+ lp = self.base_dist.log_prob(pre_tanh_value)
+ tanh_lp = torch.log(1 - value * value + self.tanh_epsilon)
+ # In case the base dist already sums up the log probs, make sure we do the same
+ return lp - tanh_lp if len(lp.shape) == len(tanh_lp.shape) else lp - tanh_lp.sum(-1)
+
+ def sample(self, sample_shape=torch.Size(), return_pretanh_value=False):
+ """
+ Gradients will and should *not* pass through this operation.
+ See https://github.com/pytorch/pytorch/issues/4620 for discussion.
+ """
+ z = self.base_dist.sample(sample_shape=sample_shape).detach()
+
+ if return_pretanh_value:
+ return torch.tanh(z) * self.scale, z
+ else:
+ return torch.tanh(z) * self.scale
+
+ def rsample(self, sample_shape=torch.Size(), return_pretanh_value=False):
+ """
+ Sampling in the reparameterization case - for differentiable samples.
+ """
+ z = self.base_dist.rsample(sample_shape=sample_shape)
+
+ if return_pretanh_value:
+ return torch.tanh(z) * self.scale, z
+ else:
+ return torch.tanh(z) * self.scale
+
+ @property
+ def mean(self):
+ return self.base_dist.mean
+
+ @property
+ def stddev(self):
+ return self.base_dist.stddev
+
+
+class DiscreteValueDistribution(object):
+ """
+ Extension to torch categorical probability distribution in order to keep track
+ of the support (categorical values, or in this case, value atoms). This is
+ used for distributional value networks.
+ """
+ def __init__(self, values, probs=None, logits=None):
+ """
+ Creates a categorical distribution parameterized by either @probs or
+ @logits (but not both). Expects inputs to be consistent in shape
+ for broadcasting operations (e.g. multiplication).
+ """
+ self._values = values
+ self._categorical_dist = D.Categorical(probs=probs, logits=logits)
+
+ @property
+ def values(self):
+ return self._values
+
+ @property
+ def probs(self):
+ return self._categorical_dist.probs
+
+ @property
+ def logits(self):
+ return self._categorical_dist.logits
+
+ def mean(self):
+ """
+ Categorical distribution mean, taking the value support into account.
+ """
+ return (self._categorical_dist.probs * self._values).sum(dim=-1)
+
+ def variance(self):
+ """
+ Categorical distribution variance, taking the value support into account.
+ """
+ dist_squared = (self.mean().unsqueeze(-1) - self.values).pow(2)
+ return (self._categorical_dist.probs * dist_squared).sum(dim=-1)
+
+ def sample(self, sample_shape=torch.Size()):
+ """
+ Sample from the distribution. Make sure to return value atoms, not categorical class indices.
+ """
+ inds = self._categorical_dist.sample(sample_shape=sample_shape)
+ return torch.gather(self.values, inds, dim=-1)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/obs_core.py b/phantom/submodules/phantom-robomimic/robomimic/models/obs_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..4183043837c0eda0901f38a93c348e4085128b96
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/obs_core.py
@@ -0,0 +1,829 @@
+"""
+Contains torch Modules for core observation processing blocks
+such as encoders (e.g. EncoderCore, VisualCore, ScanCore, ...)
+and randomizers (e.g. Randomizer, CropRandomizer).
+"""
+
+import abc
+import numpy as np
+import textwrap
+import random
+
+import torch
+import torch.nn as nn
+from torchvision.transforms import Lambda, Compose
+import torchvision.transforms.functional as TVF
+
+import robomimic.models.base_nets as BaseNets
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+from robomimic.utils.python_utils import extract_class_init_kwargs_from_dict
+
+# NOTE: this is required for the backbone classes to be found by the `eval` call in the core networks
+from robomimic.models.base_nets import *
+from robomimic.utils.vis_utils import visualize_image_randomizer
+from robomimic.macros import VISUALIZE_RANDOMIZER
+
+
+"""
+================================================
+Encoder Core Networks (Abstract class)
+================================================
+"""
+class EncoderCore(BaseNets.Module):
+ """
+ Abstract class used to categorize all cores used to encode observations
+ """
+ def __init__(self, input_shape):
+ self.input_shape = input_shape
+ super(EncoderCore, self).__init__()
+
+ def __init_subclass__(cls, **kwargs):
+ """
+ Hook method to automatically register all valid subclasses so we can keep track of valid observation encoders
+ in a global dict.
+
+ This global dict stores mapping from observation encoder network name to class.
+ We keep track of these registries to enable automated class inference at runtime, allowing
+ users to simply extend our base encoder class and refer to that class in string form
+ in their config, without having to manually register their class internally.
+ This also future-proofs us for any additional encoder classes we would
+ like to add ourselves.
+ """
+ ObsUtils.register_encoder_core(cls)
+
+
+"""
+================================================
+Visual Core Networks (Backbone + Pool)
+================================================
+"""
+class VisualCore(EncoderCore, BaseNets.ConvBase):
+ """
+ A network block that combines a visual backbone network with optional pooling
+ and linear layers.
+ """
+ def __init__(
+ self,
+ input_shape,
+ backbone_class="ResNet18Conv",
+ pool_class="SpatialSoftmax",
+ backbone_kwargs=None,
+ pool_kwargs=None,
+ flatten=True,
+ feature_dimension=64,
+ ):
+ """
+ Args:
+ input_shape (tuple): shape of input (not including batch dimension)
+ backbone_class (str): class name for the visual backbone network. Defaults
+ to "ResNet18Conv".
+ pool_class (str): class name for the visual feature pooler (optional)
+ Common options are "SpatialSoftmax" and "SpatialMeanPool". Defaults to
+ "SpatialSoftmax".
+ backbone_kwargs (dict): kwargs for the visual backbone network (optional)
+ pool_kwargs (dict): kwargs for the visual feature pooler (optional)
+ flatten (bool): whether to flatten the visual features
+ feature_dimension (int): if not None, add a Linear layer to
+ project output into a desired feature dimension
+ """
+ super(VisualCore, self).__init__(input_shape=input_shape)
+ self.flatten = flatten
+
+ if backbone_kwargs is None:
+ backbone_kwargs = dict()
+
+ # add input channel dimension to visual core inputs
+ backbone_kwargs["input_channel"] = input_shape[0]
+
+ # extract only relevant kwargs for this specific backbone
+ backbone_kwargs = extract_class_init_kwargs_from_dict(cls=eval(backbone_class), dic=backbone_kwargs, copy=True)
+
+ # visual backbone
+ assert isinstance(backbone_class, str)
+ self.backbone = eval(backbone_class)(**backbone_kwargs)
+
+ assert isinstance(self.backbone, BaseNets.ConvBase)
+
+ feat_shape = self.backbone.output_shape(input_shape)
+ net_list = [self.backbone]
+
+ # maybe make pool net
+ if pool_class is not None:
+ assert isinstance(pool_class, str)
+ # feed output shape of backbone to pool net
+ if pool_kwargs is None:
+ pool_kwargs = dict()
+ # extract only relevant kwargs for this specific backbone
+ pool_kwargs["input_shape"] = feat_shape
+ pool_kwargs = extract_class_init_kwargs_from_dict(cls=eval(pool_class), dic=pool_kwargs, copy=True)
+ self.pool = eval(pool_class)(**pool_kwargs)
+ assert isinstance(self.pool, BaseNets.Module)
+
+ feat_shape = self.pool.output_shape(feat_shape)
+ net_list.append(self.pool)
+ else:
+ self.pool = None
+
+ # flatten layer
+ if self.flatten:
+ net_list.append(torch.nn.Flatten(start_dim=1, end_dim=-1))
+
+ # maybe linear layer
+ self.feature_dimension = feature_dimension
+ if feature_dimension is not None:
+ assert self.flatten
+ linear = torch.nn.Linear(int(np.prod(feat_shape)), feature_dimension)
+ net_list.append(linear)
+
+ self.nets = nn.Sequential(*net_list)
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ if self.feature_dimension is not None:
+ # linear output
+ return [self.feature_dimension]
+ feat_shape = self.backbone.output_shape(input_shape)
+ if self.pool is not None:
+ # pool output
+ feat_shape = self.pool.output_shape(feat_shape)
+ # backbone + flat output
+ if self.flatten:
+ return [np.prod(feat_shape)]
+ else:
+ return feat_shape
+
+ def forward(self, inputs):
+ """
+ Forward pass through visual core.
+ """
+ ndim = len(self.input_shape)
+ assert tuple(inputs.shape)[-ndim:] == tuple(self.input_shape)
+ return super(VisualCore, self).forward(inputs)
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 2
+ msg += textwrap.indent(
+ "\ninput_shape={}\noutput_shape={}".format(self.input_shape, self.output_shape(self.input_shape)), indent)
+ msg += textwrap.indent("\nbackbone_net={}".format(self.backbone), indent)
+ msg += textwrap.indent("\npool_net={}".format(self.pool), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+"""
+================================================
+Scan Core Networks (Conv1D Sequential + Pool)
+================================================
+"""
+class ScanCore(EncoderCore, BaseNets.ConvBase):
+ """
+ A network block that combines a Conv1D backbone network with optional pooling
+ and linear layers.
+ """
+ def __init__(
+ self,
+ input_shape,
+ conv_kwargs=None,
+ conv_activation="relu",
+ pool_class=None,
+ pool_kwargs=None,
+ flatten=True,
+ feature_dimension=None,
+ ):
+ """
+ Args:
+ input_shape (tuple): shape of input (not including batch dimension)
+ conv_kwargs (dict): kwargs for the conv1d backbone network. Should contain lists for the following values:
+ out_channels (int)
+ kernel_size (int)
+ stride (int)
+ ...
+
+ If not specified, or an empty dictionary is specified, some default settings will be used.
+ conv_activation (str or None): Activation to use between conv layers. Default is relu.
+ Currently, valid options are {relu}
+ pool_class (str): class name for the visual feature pooler (optional)
+ Common options are "SpatialSoftmax" and "SpatialMeanPool"
+ pool_kwargs (dict): kwargs for the visual feature pooler (optional)
+ flatten (bool): whether to flatten the network output
+ feature_dimension (int): if not None, add a Linear layer to
+ project output into a desired feature dimension (note: flatten must be set to True!)
+ """
+ super(ScanCore, self).__init__(input_shape=input_shape)
+ self.flatten = flatten
+ self.feature_dimension = feature_dimension
+
+ if conv_kwargs is None:
+ conv_kwargs = dict()
+
+ # Generate backbone network
+ # N input channels is assumed to be the first dimension
+ self.backbone = BaseNets.Conv1dBase(
+ input_channel=self.input_shape[0],
+ activation=conv_activation,
+ **conv_kwargs,
+ )
+ feat_shape = self.backbone.output_shape(input_shape=input_shape)
+
+ # Create netlist of all generated networks
+ net_list = [self.backbone]
+
+ # Possibly add pooling network
+ if pool_class is not None:
+ # Add an unsqueeze network so that the shape is correct to pass to pooling network
+ self.unsqueeze = Unsqueeze(dim=-1)
+ net_list.append(self.unsqueeze)
+ # Get output shape
+ feat_shape = self.unsqueeze.output_shape(feat_shape)
+ # Create pooling network
+ self.pool = eval(pool_class)(input_shape=feat_shape, **pool_kwargs)
+ net_list.append(self.pool)
+ feat_shape = self.pool.output_shape(feat_shape)
+ else:
+ self.unsqueeze, self.pool = None, None
+
+ # flatten layer
+ if self.flatten:
+ net_list.append(torch.nn.Flatten(start_dim=1, end_dim=-1))
+
+ # maybe linear layer
+ if self.feature_dimension is not None:
+ assert self.flatten
+ linear = torch.nn.Linear(int(np.prod(feat_shape)), self.feature_dimension)
+ net_list.append(linear)
+
+ # Generate final network
+ self.nets = nn.Sequential(*net_list)
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ if self.feature_dimension is not None:
+ # linear output
+ return [self.feature_dimension]
+ feat_shape = self.backbone.output_shape(input_shape)
+ if self.pool is not None:
+ # pool output
+ feat_shape = self.pool.output_shape(self.unsqueeze.output_shape(feat_shape))
+ # backbone + flat output
+ return [np.prod(feat_shape)] if self.flatten else feat_shape
+
+ def forward(self, inputs):
+ """
+ Forward pass through visual core.
+ """
+ ndim = len(self.input_shape)
+ assert tuple(inputs.shape)[-ndim:] == tuple(self.input_shape)
+ return super(ScanCore, self).forward(inputs)
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 2
+ msg += textwrap.indent(
+ "\ninput_shape={}\noutput_shape={}".format(self.input_shape, self.output_shape(self.input_shape)), indent)
+ msg += textwrap.indent("\nbackbone_net={}".format(self.backbone), indent)
+ msg += textwrap.indent("\npool_net={}".format(self.pool), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+"""
+================================================
+Observation Randomizer Networks
+================================================
+"""
+class Randomizer(BaseNets.Module):
+ """
+ Base class for randomizer networks. Each randomizer should implement the @output_shape_in,
+ @output_shape_out, @forward_in, and @forward_out methods. The randomizer's @forward_in
+ method is invoked on raw inputs, and @forward_out is invoked on processed inputs
+ (usually processed by a @VisualCore instance). Note that the self.training property
+ can be used to change the randomizer's behavior at train vs. test time.
+ """
+ def __init__(self):
+ super(Randomizer, self).__init__()
+
+ def __init_subclass__(cls, **kwargs):
+ """
+ Hook method to automatically register all valid subclasses so we can keep track of valid observation randomizers
+ in a global dict.
+
+ This global dict stores mapping from observation randomizer network name to class.
+ We keep track of these registries to enable automated class inference at runtime, allowing
+ users to simply extend our base randomizer class and refer to that class in string form
+ in their config, without having to manually register their class internally.
+ This also future-proofs us for any additional randomizer classes we would
+ like to add ourselves.
+ """
+ ObsUtils.register_randomizer(cls)
+
+ def output_shape(self, input_shape=None):
+ """
+ This function is unused. See @output_shape_in and @output_shape_out.
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def output_shape_in(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module. Corresponds to
+ the @forward_in operation, where raw inputs (usually observation modalities)
+ are passed in.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def output_shape_out(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module. Corresponds to
+ the @forward_out operation, where processed inputs (usually encoded observation
+ modalities) are passed in.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ raise NotImplementedError
+
+ def forward_in(self, inputs):
+ """
+ Randomize raw inputs if training.
+ """
+ if self.training:
+ randomized_inputs = self._forward_in(inputs=inputs)
+ if VISUALIZE_RANDOMIZER:
+ num_samples_to_visualize = min(4, inputs.shape[0])
+ self._visualize(inputs, randomized_inputs, num_samples_to_visualize=num_samples_to_visualize)
+ return randomized_inputs
+ else:
+ return self._forward_in_eval(inputs)
+
+ def forward_out(self, inputs):
+ """
+ Processing for network outputs.
+ """
+ if self.training:
+ return self._forward_out(inputs)
+ else:
+ return self._forward_out_eval(inputs)
+
+ @abc.abstractmethod
+ def _forward_in(self, inputs):
+ """
+ Randomize raw inputs.
+ """
+ raise NotImplementedError
+
+ def _forward_in_eval(self, inputs):
+ """
+ Test-time behavior for the randomizer
+ """
+ return inputs
+
+ @abc.abstractmethod
+ def _forward_out(self, inputs):
+ """
+ Processing for network outputs.
+ """
+ return inputs
+
+ def _forward_out_eval(self, inputs):
+ """
+ Test-time behavior for the randomizer
+ """
+ return inputs
+
+ @abc.abstractmethod
+ def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2):
+ """
+ Visualize the original input and the randomized input for _forward_in for debugging purposes.
+ """
+ pass
+
+
+class CropRandomizer(Randomizer):
+ """
+ Randomly sample crops at input, and then average across crop features at output.
+ """
+ def __init__(
+ self,
+ input_shape,
+ crop_height=76,
+ crop_width=76,
+ num_crops=1,
+ pos_enc=False,
+ ):
+ """
+ Args:
+ input_shape (tuple, list): shape of input (not including batch dimension)
+ crop_height (int): crop height
+ crop_width (int): crop width
+ num_crops (int): number of random crops to take
+ pos_enc (bool): if True, add 2 channels to the output to encode the spatial
+ location of the cropped pixels in the source image
+ """
+ super(CropRandomizer, self).__init__()
+
+ assert len(input_shape) == 3 # (C, H, W)
+ assert crop_height < input_shape[1]
+ assert crop_width < input_shape[2]
+
+ self.input_shape = input_shape
+ self.crop_height = crop_height
+ self.crop_width = crop_width
+ self.num_crops = num_crops
+ self.pos_enc = pos_enc
+
+ def output_shape_in(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module. Corresponds to
+ the @forward_in operation, where raw inputs (usually observation modalities)
+ are passed in.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
+ # the number of crops are reshaped into the batch dimension, increasing the batch
+ # size from B to B * N
+ out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
+ return [out_c, self.crop_height, self.crop_width]
+
+ def output_shape_out(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module. Corresponds to
+ the @forward_out operation, where processed inputs (usually encoded observation
+ modalities) are passed in.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
+ # and then pools to result in [B, ...], only the batch dimension changes,
+ # and so the other dimensions retain their shape.
+ return list(input_shape)
+
+ def _forward_in(self, inputs):
+ """
+ Samples N random crops for each input in the batch, and then reshapes
+ inputs to [B * N, ...].
+ """
+ assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
+ out, _ = ObsUtils.sample_random_image_crops(
+ images=inputs,
+ crop_height=self.crop_height,
+ crop_width=self.crop_width,
+ num_crops=self.num_crops,
+ pos_enc=self.pos_enc,
+ )
+ # [B, N, ...] -> [B * N, ...]
+ return TensorUtils.join_dimensions(out, 0, 1)
+
+ def _forward_in_eval(self, inputs):
+ """
+ Do center crops during eval
+ """
+ assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
+ inputs = inputs.permute(*range(inputs.dim()-3), inputs.dim()-2, inputs.dim()-1, inputs.dim()-3)
+ out = ObsUtils.center_crop(inputs, self.crop_height, self.crop_width)
+ out = out.permute(*range(out.dim()-3), out.dim()-1, out.dim()-3, out.dim()-2)
+ return out
+
+ def _forward_out(self, inputs):
+ """
+ Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
+ to result in shape [B, ...] to make sure the network output is consistent with
+ what would have happened if there were no randomization.
+ """
+ batch_size = (inputs.shape[0] // self.num_crops)
+ out = TensorUtils.reshape_dimensions(inputs, begin_axis=0, end_axis=0,
+ target_dims=(batch_size, self.num_crops))
+ return out.mean(dim=1)
+
+ def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2):
+ batch_size = pre_random_input.shape[0]
+ random_sample_inds = torch.randint(0, batch_size, size=(num_samples_to_visualize,))
+ pre_random_input_np = TensorUtils.to_numpy(pre_random_input)[random_sample_inds]
+ randomized_input = TensorUtils.reshape_dimensions(
+ randomized_input,
+ begin_axis=0,
+ end_axis=0,
+ target_dims=(batch_size, self.num_crops)
+ ) # [B * N, ...] -> [B, N, ...]
+ randomized_input_np = TensorUtils.to_numpy(randomized_input[random_sample_inds])
+
+ pre_random_input_np = pre_random_input_np.transpose((0, 2, 3, 1)) # [B, C, H, W] -> [B, H, W, C]
+ randomized_input_np = randomized_input_np.transpose((0, 1, 3, 4, 2)) # [B, N, C, H, W] -> [B, N, H, W, C]
+
+ visualize_image_randomizer(
+ pre_random_input_np,
+ randomized_input_np,
+ randomizer_name='{}'.format(str(self.__class__.__name__))
+ )
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
+ self.input_shape, self.crop_height, self.crop_width, self.num_crops)
+ return msg
+
+
+class ColorRandomizer(Randomizer):
+ """
+ Randomly sample color jitter at input, and then average across color jtters at output.
+ """
+ def __init__(
+ self,
+ input_shape,
+ brightness=0.3,
+ contrast=0.3,
+ saturation=0.3,
+ hue=0.3,
+ num_samples=1,
+ ):
+ """
+ Args:
+ input_shape (tuple, list): shape of input (not including batch dimension)
+ brightness (None or float or 2-tuple): How much to jitter brightness. brightness_factor is chosen uniformly
+ from [max(0, 1 - brightness), 1 + brightness] or the given [min, max]. Should be non negative numbers.
+ contrast (None or float or 2-tuple): How much to jitter contrast. contrast_factor is chosen uniformly
+ from [max(0, 1 - contrast), 1 + contrast] or the given [min, max]. Should be non negative numbers.
+ saturation (None or float or 2-tuple): How much to jitter saturation. saturation_factor is chosen uniformly
+ from [max(0, 1 - saturation), 1 + saturation] or the given [min, max]. Should be non negative numbers.
+ hue (None or float or 2-tuple): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or
+ the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. To jitter hue, the pixel
+ values of the input image has to be non-negative for conversion to HSV space; thus it does not work
+ if you normalize your image to an interval with negative values, or use an interpolation that
+ generates negative values before using this function.
+ num_samples (int): number of random color jitters to take
+ """
+ super(ColorRandomizer, self).__init__()
+
+ assert len(input_shape) == 3 # (C, H, W)
+
+ self.input_shape = input_shape
+ self.brightness = [max(0, 1 - brightness), 1 + brightness] if type(brightness) in {float, int} else brightness
+ self.contrast = [max(0, 1 - contrast), 1 + contrast] if type(contrast) in {float, int} else contrast
+ self.saturation = [max(0, 1 - saturation), 1 + saturation] if type(saturation) in {float, int} else saturation
+ self.hue = [-hue, hue] if type(hue) in {float, int} else hue
+ self.num_samples = num_samples
+
+ @torch.jit.unused
+ def get_transform(self):
+ """
+ Get a randomized transform to be applied on image.
+
+ Implementation taken directly from:
+
+ https://github.com/pytorch/vision/blob/2f40a483d73018ae6e1488a484c5927f2b309969/torchvision/transforms/transforms.py#L1053-L1085
+
+ Returns:
+ Transform: Transform which randomly adjusts brightness, contrast and
+ saturation in a random order.
+ """
+ transforms = []
+
+ if self.brightness is not None:
+ brightness_factor = random.uniform(self.brightness[0], self.brightness[1])
+ transforms.append(Lambda(lambda img: TVF.adjust_brightness(img, brightness_factor)))
+
+ if self.contrast is not None:
+ contrast_factor = random.uniform(self.contrast[0], self.contrast[1])
+ transforms.append(Lambda(lambda img: TVF.adjust_contrast(img, contrast_factor)))
+
+ if self.saturation is not None:
+ saturation_factor = random.uniform(self.saturation[0], self.saturation[1])
+ transforms.append(Lambda(lambda img: TVF.adjust_saturation(img, saturation_factor)))
+
+ if self.hue is not None:
+ hue_factor = random.uniform(self.hue[0], self.hue[1])
+ transforms.append(Lambda(lambda img: TVF.adjust_hue(img, hue_factor)))
+
+ random.shuffle(transforms)
+ transform = Compose(transforms)
+
+ return transform
+
+ def get_batch_transform(self, N):
+ """
+ Generates a batch transform, where each set of sample(s) along the batch (first) dimension will have the same
+ @N unique ColorJitter transforms applied.
+
+ Args:
+ N (int): Number of ColorJitter transforms to apply per set of sample(s) along the batch (first) dimension
+
+ Returns:
+ Lambda: Aggregated transform which will autoamtically apply a different ColorJitter transforms to
+ each sub-set of samples along batch dimension, assumed to be the FIRST dimension in the inputted tensor
+ Note: This function will MULTIPLY the first dimension by N
+ """
+ return Lambda(lambda x: torch.stack([self.get_transform()(x_) for x_ in x for _ in range(N)]))
+
+ def output_shape_in(self, input_shape=None):
+ # outputs are same shape as inputs
+ return list(input_shape)
+
+ def output_shape_out(self, input_shape=None):
+ # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
+ # and then pools to result in [B, ...], only the batch dimension changes,
+ # and so the other dimensions retain their shape.
+ return list(input_shape)
+
+ def _forward_in(self, inputs):
+ """
+ Samples N random color jitters for each input in the batch, and then reshapes
+ inputs to [B * N, ...].
+ """
+ assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
+
+ # Make sure shape is exactly 4
+ if len(inputs.shape) == 3:
+ inputs = torch.unsqueeze(inputs, dim=0)
+
+ # Create lambda to aggregate all color randomizings at once
+ transform = self.get_batch_transform(N=self.num_samples)
+
+ return transform(inputs)
+
+ def _forward_out(self, inputs):
+ """
+ Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
+ to result in shape [B, ...] to make sure the network output is consistent with
+ what would have happened if there were no randomization.
+ """
+ batch_size = (inputs.shape[0] // self.num_samples)
+ out = TensorUtils.reshape_dimensions(inputs, begin_axis=0, end_axis=0,
+ target_dims=(batch_size, self.num_samples))
+ return out.mean(dim=1)
+
+ def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2):
+ batch_size = pre_random_input.shape[0]
+ random_sample_inds = torch.randint(0, batch_size, size=(num_samples_to_visualize,))
+ pre_random_input_np = TensorUtils.to_numpy(pre_random_input)[random_sample_inds]
+ randomized_input = TensorUtils.reshape_dimensions(
+ randomized_input,
+ begin_axis=0,
+ end_axis=0,
+ target_dims=(batch_size, self.num_samples)
+ ) # [B * N, ...] -> [B, N, ...]
+ randomized_input_np = TensorUtils.to_numpy(randomized_input[random_sample_inds])
+
+ pre_random_input_np = pre_random_input_np.transpose((0, 2, 3, 1)) # [B, C, H, W] -> [B, H, W, C]
+ randomized_input_np = randomized_input_np.transpose((0, 1, 3, 4, 2)) # [B, N, C, H, W] -> [B, N, H, W, C]
+
+ visualize_image_randomizer(
+ pre_random_input_np,
+ randomized_input_np,
+ randomizer_name='{}'.format(str(self.__class__.__name__))
+ )
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = header + f"(input_shape={self.input_shape}, brightness={self.brightness}, contrast={self.contrast}, " \
+ f"saturation={self.saturation}, hue={self.hue}, num_samples={self.num_samples})"
+ return msg
+
+
+class GaussianNoiseRandomizer(Randomizer):
+ """
+ Randomly sample gaussian noise at input, and then average across noises at output.
+ """
+ def __init__(
+ self,
+ input_shape,
+ noise_mean=0.0,
+ noise_std=0.3,
+ limits=None,
+ num_samples=1,
+ ):
+ """
+ Args:
+ input_shape (tuple, list): shape of input (not including batch dimension)
+ noise_mean (float): Mean of noise to apply
+ noise_std (float): Standard deviation of noise to apply
+ limits (None or 2-tuple): If specified, should be the (min, max) values to clamp all noisied samples to
+ num_samples (int): number of random color jitters to take
+ """
+ super(GaussianNoiseRandomizer, self).__init__()
+
+ self.input_shape = input_shape
+ self.noise_mean = noise_mean
+ self.noise_std = noise_std
+ self.limits = limits
+ self.num_samples = num_samples
+
+ def output_shape_in(self, input_shape=None):
+ # outputs are same shape as inputs
+ return list(input_shape)
+
+ def output_shape_out(self, input_shape=None):
+ # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
+ # and then pools to result in [B, ...], only the batch dimension changes,
+ # and so the other dimensions retain their shape.
+ return list(input_shape)
+
+ def _forward_in(self, inputs):
+ """
+ Samples N random gaussian noises for each input in the batch, and then reshapes
+ inputs to [B * N, ...].
+ """
+ out = TensorUtils.repeat_by_expand_at(inputs, repeats=self.num_samples, dim=0)
+
+ # Sample noise across all samples
+ out = torch.rand(size=out.shape).to(inputs.device) * self.noise_std + self.noise_mean + out
+
+ # Possibly clamp
+ if self.limits is not None:
+ out = torch.clip(out, min=self.limits[0], max=self.limits[1])
+
+ return out
+
+ def _forward_out(self, inputs):
+ """
+ Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
+ to result in shape [B, ...] to make sure the network output is consistent with
+ what would have happened if there were no randomization.
+ """
+ batch_size = (inputs.shape[0] // self.num_samples)
+ out = TensorUtils.reshape_dimensions(inputs, begin_axis=0, end_axis=0,
+ target_dims=(batch_size, self.num_samples))
+ return out.mean(dim=1)
+
+ def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2):
+ batch_size = pre_random_input.shape[0]
+ random_sample_inds = torch.randint(0, batch_size, size=(num_samples_to_visualize,))
+ pre_random_input_np = TensorUtils.to_numpy(pre_random_input)[random_sample_inds]
+ randomized_input = TensorUtils.reshape_dimensions(
+ randomized_input,
+ begin_axis=0,
+ end_axis=0,
+ target_dims=(batch_size, self.num_samples)
+ ) # [B * N, ...] -> [B, N, ...]
+ randomized_input_np = TensorUtils.to_numpy(randomized_input[random_sample_inds])
+
+ pre_random_input_np = pre_random_input_np.transpose((0, 2, 3, 1)) # [B, C, H, W] -> [B, H, W, C]
+ randomized_input_np = randomized_input_np.transpose((0, 1, 3, 4, 2)) # [B, N, C, H, W] -> [B, N, H, W, C]
+
+ visualize_image_randomizer(
+ pre_random_input_np,
+ randomized_input_np,
+ randomizer_name='{}'.format(str(self.__class__.__name__))
+ )
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = header + f"(input_shape={self.input_shape}, noise_mean={self.noise_mean}, noise_std={self.noise_std}, " \
+ f"limits={self.limits}, num_samples={self.num_samples})"
+ return msg
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/obs_nets.py b/phantom/submodules/phantom-robomimic/robomimic/models/obs_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..b328418505d4aedefcf43b0c3cbd6dd87ae05c37
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/obs_nets.py
@@ -0,0 +1,1099 @@
+"""
+Contains torch Modules that help deal with inputs consisting of multiple
+modalities. This is extremely common when networks must deal with one or
+more observation dictionaries, where each input dictionary can have
+observation keys of a certain modality and shape.
+
+As an example, an observation could consist of a flat "robot0_eef_pos" observation key,
+and a 3-channel RGB "agentview_image" observation key.
+"""
+import sys
+import numpy as np
+import textwrap
+from copy import deepcopy
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as D
+
+from robomimic.utils.python_utils import extract_class_init_kwargs_from_dict
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+from robomimic.models.base_nets import Module, Sequential, MLP, RNN_Base, ResNet18Conv, SpatialSoftmax, \
+ FeatureAggregator
+from robomimic.models.obs_core import VisualCore, Randomizer
+from robomimic.models.transformers import PositionalEncoding, GPT_Backbone
+
+
+def obs_encoder_factory(
+ obs_shapes,
+ feature_activation=nn.ReLU,
+ encoder_kwargs=None,
+ ):
+ """
+ Utility function to create an @ObservationEncoder from kwargs specified in config.
+
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps observation key to
+ expected shapes for observations.
+
+ feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
+ None to apply no activation.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be
+ nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ enc = ObservationEncoder(feature_activation=feature_activation)
+ for k, obs_shape in obs_shapes.items():
+ obs_modality = ObsUtils.OBS_KEYS_TO_MODALITIES[k]
+ enc_kwargs = deepcopy(ObsUtils.DEFAULT_ENCODER_KWARGS[obs_modality]) if encoder_kwargs is None else \
+ deepcopy(encoder_kwargs[obs_modality])
+
+ for obs_module, cls_mapping in zip(("core", "obs_randomizer"),
+ (ObsUtils.OBS_ENCODER_CORES, ObsUtils.OBS_RANDOMIZERS)):
+ # Sanity check for kwargs in case they don't exist / are None
+ if enc_kwargs.get(f"{obs_module}_kwargs", None) is None:
+ enc_kwargs[f"{obs_module}_kwargs"] = {}
+ # Add in input shape info
+ enc_kwargs[f"{obs_module}_kwargs"]["input_shape"] = obs_shape
+ # If group class is specified, then make sure corresponding kwargs only contain relevant kwargs
+ if enc_kwargs[f"{obs_module}_class"] is not None:
+ enc_kwargs[f"{obs_module}_kwargs"] = extract_class_init_kwargs_from_dict(
+ cls=cls_mapping[enc_kwargs[f"{obs_module}_class"]],
+ dic=enc_kwargs[f"{obs_module}_kwargs"],
+ copy=False,
+ )
+
+ # Add in input shape info
+ randomizer = None if enc_kwargs["obs_randomizer_class"] is None else \
+ ObsUtils.OBS_RANDOMIZERS[enc_kwargs["obs_randomizer_class"]](**enc_kwargs["obs_randomizer_kwargs"])
+
+ enc.register_obs_key(
+ name=k,
+ shape=obs_shape,
+ net_class=enc_kwargs["core_class"],
+ net_kwargs=enc_kwargs["core_kwargs"],
+ randomizer=randomizer,
+ )
+
+ enc.make()
+ return enc
+
+
+class ObservationEncoder(Module):
+ """
+ Module that processes inputs by observation key and then concatenates the processed
+ observation keys together. Each key is processed with an encoder head network.
+ Call @register_obs_key to register observation keys with the encoder and then
+ finally call @make to create the encoder networks.
+ """
+ def __init__(self, feature_activation=nn.ReLU):
+ """
+ Args:
+ feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
+ None to apply no activation.
+ """
+ super(ObservationEncoder, self).__init__()
+ self.obs_shapes = OrderedDict()
+ self.obs_nets_classes = OrderedDict()
+ self.obs_nets_kwargs = OrderedDict()
+ self.obs_share_mods = OrderedDict()
+ self.obs_nets = nn.ModuleDict()
+ self.obs_randomizers = nn.ModuleDict()
+ self.feature_activation = feature_activation
+ self._locked = False
+
+ def register_obs_key(
+ self,
+ name,
+ shape,
+ net_class=None,
+ net_kwargs=None,
+ net=None,
+ randomizer=None,
+ share_net_from=None,
+ ):
+ """
+ Register an observation key that this encoder should be responsible for.
+
+ Args:
+ name (str): modality name
+ shape (int tuple): shape of modality
+ net_class (str): name of class in base_nets.py that should be used
+ to process this observation key before concatenation. Pass None to flatten
+ and concatenate the observation key directly.
+ net_kwargs (dict): arguments to pass to @net_class
+ net (Module instance): if provided, use this Module to process the observation key
+ instead of creating a different net
+ randomizer (Randomizer instance): if provided, use this Module to augment observation keys
+ coming in to the encoder, and possibly augment the processed output as well
+ share_net_from (str): if provided, use the same instance of @net_class
+ as another observation key. This observation key must already exist in this encoder.
+ Warning: Note that this does not share the observation key randomizer
+ """
+ assert not self._locked, "ObservationEncoder: @register_obs_key called after @make"
+ assert name not in self.obs_shapes, "ObservationEncoder: modality {} already exists".format(name)
+
+ if net is not None:
+ assert isinstance(net, Module), "ObservationEncoder: @net must be instance of Module class"
+ assert (net_class is None) and (net_kwargs is None) and (share_net_from is None), \
+ "ObservationEncoder: @net provided - ignore other net creation options"
+
+ if share_net_from is not None:
+ # share processing with another modality
+ assert (net_class is None) and (net_kwargs is None)
+ assert share_net_from in self.obs_shapes
+
+ net_kwargs = deepcopy(net_kwargs) if net_kwargs is not None else {}
+ if randomizer is not None:
+ assert isinstance(randomizer, Randomizer)
+ if net_kwargs is not None:
+ # update input shape to visual core
+ net_kwargs["input_shape"] = randomizer.output_shape_in(shape)
+
+ self.obs_shapes[name] = shape
+ self.obs_nets_classes[name] = net_class
+ self.obs_nets_kwargs[name] = net_kwargs
+ self.obs_nets[name] = net
+ self.obs_randomizers[name] = randomizer
+ self.obs_share_mods[name] = share_net_from
+
+ def make(self):
+ """
+ Creates the encoder networks and locks the encoder so that more modalities cannot be added.
+ """
+ assert not self._locked, "ObservationEncoder: @make called more than once"
+ self._create_layers()
+ self._locked = True
+
+ def _create_layers(self):
+ """
+ Creates all networks and layers required by this encoder using the registered modalities.
+ """
+ assert not self._locked, "ObservationEncoder: layers have already been created"
+
+ for k in self.obs_shapes:
+ if self.obs_nets_classes[k] is not None:
+ # create net to process this modality
+ self.obs_nets[k] = ObsUtils.OBS_ENCODER_CORES[self.obs_nets_classes[k]](**self.obs_nets_kwargs[k])
+ elif self.obs_share_mods[k] is not None:
+ # make sure net is shared with another modality
+ self.obs_nets[k] = self.obs_nets[self.obs_share_mods[k]]
+
+ self.activation = None
+ if self.feature_activation is not None:
+ self.activation = self.feature_activation()
+
+ def forward(self, obs_dict):
+ """
+ Processes modalities according to the ordering in @self.obs_shapes. For each
+ modality, it is processed with a randomizer (if present), an encoder
+ network (if present), and again with the randomizer (if present), flattened,
+ and then concatenated with the other processed modalities.
+
+ Args:
+ obs_dict (OrderedDict): dictionary that maps modalities to torch.Tensor
+ batches that agree with @self.obs_shapes. All modalities in
+ @self.obs_shapes must be present, but additional modalities
+ can also be present.
+
+ Returns:
+ feats (torch.Tensor): flat features of shape [B, D]
+ """
+ assert self._locked, "ObservationEncoder: @make has not been called yet"
+
+ # ensure all modalities that the encoder handles are present
+ assert set(self.obs_shapes.keys()).issubset(obs_dict), "ObservationEncoder: {} does not contain all modalities {}".format(
+ list(obs_dict.keys()), list(self.obs_shapes.keys())
+ )
+
+ # process modalities by order given by @self.obs_shapes
+ feats = []
+ for k in self.obs_shapes:
+ x = obs_dict[k]
+ # maybe process encoder input with randomizer
+ if self.obs_randomizers[k] is not None:
+ x = self.obs_randomizers[k].forward_in(x)
+ # maybe process with obs net
+ if self.obs_nets[k] is not None:
+ x = self.obs_nets[k](x)
+ if self.activation is not None:
+ x = self.activation(x)
+ # maybe process encoder output with randomizer
+ if self.obs_randomizers[k] is not None:
+ x = self.obs_randomizers[k].forward_out(x)
+ # flatten to [B, D]
+ x = TensorUtils.flatten(x, begin_axis=1)
+ feats.append(x)
+
+ # concatenate all features together
+ return torch.cat(feats, dim=-1)
+
+ def output_shape(self, input_shape=None):
+ """
+ Compute the output shape of the encoder.
+ """
+ feat_dim = 0
+ for k in self.obs_shapes:
+ feat_shape = self.obs_shapes[k]
+ if self.obs_randomizers[k] is not None:
+ feat_shape = self.obs_randomizers[k].output_shape_in(feat_shape)
+ if self.obs_nets[k] is not None:
+ feat_shape = self.obs_nets[k].output_shape(feat_shape)
+ if self.obs_randomizers[k] is not None:
+ feat_shape = self.obs_randomizers[k].output_shape_out(feat_shape)
+ feat_dim += int(np.prod(feat_shape))
+ return [feat_dim]
+
+ def __repr__(self):
+ """
+ Pretty print the encoder.
+ """
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ for k in self.obs_shapes:
+ msg += textwrap.indent('\nKey(\n', ' ' * 4)
+ indent = ' ' * 8
+ msg += textwrap.indent("name={}\nshape={}\n".format(k, self.obs_shapes[k]), indent)
+ msg += textwrap.indent("modality={}\n".format(ObsUtils.OBS_KEYS_TO_MODALITIES[k]), indent)
+ msg += textwrap.indent("randomizer={}\n".format(self.obs_randomizers[k]), indent)
+ msg += textwrap.indent("net={}\n".format(self.obs_nets[k]), indent)
+ msg += textwrap.indent("sharing_from={}\n".format(self.obs_share_mods[k]), indent)
+ msg += textwrap.indent(")", ' ' * 4)
+ msg += textwrap.indent("\noutput_shape={}".format(self.output_shape()), ' ' * 4)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+class ObservationDecoder(Module):
+ """
+ Module that can generate observation outputs by modality. Inputs are assumed
+ to be flat (usually outputs from some hidden layer). Each observation output
+ is generated with a linear layer from these flat inputs. Subclass this
+ module in order to implement more complex schemes for generating each
+ modality.
+ """
+ def __init__(
+ self,
+ decode_shapes,
+ input_feat_dim,
+ ):
+ """
+ Args:
+ decode_shapes (OrderedDict): a dictionary that maps observation key to
+ expected shape. This is used to generate output modalities from the
+ input features.
+
+ input_feat_dim (int): flat input dimension size
+ """
+ super(ObservationDecoder, self).__init__()
+
+ # important: sort observation keys to ensure consistent ordering of modalities
+ assert isinstance(decode_shapes, OrderedDict)
+ self.obs_shapes = OrderedDict()
+ for k in decode_shapes:
+ self.obs_shapes[k] = decode_shapes[k]
+
+ self.input_feat_dim = input_feat_dim
+ self._create_layers()
+
+ def _create_layers(self):
+ """
+ Create a linear layer to predict each modality.
+ """
+ self.nets = nn.ModuleDict()
+ for k in self.obs_shapes:
+ layer_out_dim = int(np.prod(self.obs_shapes[k]))
+ self.nets[k] = nn.Linear(self.input_feat_dim, layer_out_dim)
+
+ def output_shape(self, input_shape=None):
+ """
+ Returns output shape for this module, which is a dictionary instead
+ of a list since outputs are dictionaries.
+ """
+ return { k : list(self.obs_shapes[k]) for k in self.obs_shapes }
+
+ def forward(self, feats):
+ """
+ Predict each modality from input features, and reshape to each modality's shape.
+ """
+ output = {}
+ for k in self.obs_shapes:
+ out = self.nets[k](feats)
+ output[k] = out.reshape(-1, *self.obs_shapes[k])
+ return output
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ for k in self.obs_shapes:
+ msg += textwrap.indent('\nKey(\n', ' ' * 4)
+ indent = ' ' * 8
+ msg += textwrap.indent("name={}\nshape={}\n".format(k, self.obs_shapes[k]), indent)
+ msg += textwrap.indent("modality={}\n".format(ObsUtils.OBS_KEYS_TO_MODALITIES[k]), indent)
+ msg += textwrap.indent("net=({})\n".format(self.nets[k]), indent)
+ msg += textwrap.indent(")", ' ' * 4)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+class ObservationGroupEncoder(Module):
+ """
+ This class allows networks to encode multiple observation dictionaries into a single
+ flat, concatenated vector representation. It does this by assigning each observation
+ dictionary (observation group) an @ObservationEncoder object.
+
+ The class takes a dictionary of dictionaries, @observation_group_shapes.
+ Each key corresponds to a observation group (e.g. 'obs', 'subgoal', 'goal')
+ and each OrderedDict should be a map between modalities and
+ expected input shapes (e.g. { 'image' : (3, 120, 160) }).
+ """
+ def __init__(
+ self,
+ observation_group_shapes,
+ feature_activation=nn.ReLU,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ observation_group_shapes (OrderedDict): a dictionary of dictionaries.
+ Each key in this dictionary should specify an observation group, and
+ the value should be an OrderedDict that maps modalities to
+ expected shapes.
+
+ feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
+ None to apply no activation.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ super(ObservationGroupEncoder, self).__init__()
+
+ # type checking
+ assert isinstance(observation_group_shapes, OrderedDict)
+ assert np.all([isinstance(observation_group_shapes[k], OrderedDict) for k in observation_group_shapes])
+
+ self.observation_group_shapes = observation_group_shapes
+
+ # create an observation encoder per observation group
+ self.nets = nn.ModuleDict()
+ for obs_group in self.observation_group_shapes:
+ self.nets[obs_group] = obs_encoder_factory(
+ obs_shapes=self.observation_group_shapes[obs_group],
+ feature_activation=feature_activation,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def forward(self, **inputs):
+ """
+ Process each set of inputs in its own observation group.
+
+ Args:
+ inputs (dict): dictionary that maps observation groups to observation
+ dictionaries of torch.Tensor batches that agree with
+ @self.observation_group_shapes. All observation groups in
+ @self.observation_group_shapes must be present, but additional
+ observation groups can also be present. Note that these are specified
+ as kwargs for ease of use with networks that name each observation
+ stream in their forward calls.
+
+ Returns:
+ outputs (torch.Tensor): flat outputs of shape [B, D]
+ """
+
+ # ensure all observation groups we need are present
+ assert set(self.observation_group_shapes.keys()).issubset(inputs), "{} does not contain all observation groups {}".format(
+ list(inputs.keys()), list(self.observation_group_shapes.keys())
+ )
+
+ outputs = []
+ # Deterministic order since self.observation_group_shapes is OrderedDict
+ for obs_group in self.observation_group_shapes:
+ # pass through encoder
+ outputs.append(
+ self.nets[obs_group].forward(inputs[obs_group])
+ )
+
+ return torch.cat(outputs, dim=-1)
+
+ def output_shape(self):
+ """
+ Compute the output shape of this encoder.
+ """
+ feat_dim = 0
+ for obs_group in self.observation_group_shapes:
+ # get feature dimension of these keys
+ feat_dim += self.nets[obs_group].output_shape()[0]
+ return [feat_dim]
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ for k in self.observation_group_shapes:
+ msg += '\n'
+ indent = ' ' * 4
+ msg += textwrap.indent("group={}\n{}".format(k, self.nets[k]), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+class MIMO_MLP(Module):
+ """
+ Extension to MLP to accept multiple observation dictionaries as input and
+ to output dictionaries of tensors. Inputs are specified as a dictionary of
+ observation dictionaries, with each key corresponding to an observation group.
+
+ This module utilizes @ObservationGroupEncoder to process the multiple input dictionaries and
+ @ObservationDecoder to generate tensor dictionaries. The default behavior
+ for encoding the inputs is to process visual inputs with a learned CNN and concatenating
+ the flat encodings with the other flat inputs. The default behavior for generating
+ outputs is to use a linear layer branch to produce each modality separately
+ (including visual outputs).
+ """
+ def __init__(
+ self,
+ input_obs_group_shapes,
+ output_shapes,
+ layer_dims,
+ layer_func=nn.Linear,
+ activation=nn.ReLU,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
+ Each key in this dictionary should specify an observation group, and
+ the value should be an OrderedDict that maps modalities to
+ expected shapes.
+
+ output_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for outputs.
+
+ layer_dims ([int]): sequence of integers for the MLP hidden layer sizes
+
+ layer_func: mapping per MLP layer - defaults to Linear
+
+ activation: non-linearity per MLP layer - defaults to ReLU
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ super(MIMO_MLP, self).__init__()
+
+ assert isinstance(input_obs_group_shapes, OrderedDict)
+ assert np.all([isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes])
+ assert isinstance(output_shapes, OrderedDict)
+
+ self.input_obs_group_shapes = input_obs_group_shapes
+ self.output_shapes = output_shapes
+
+ self.nets = nn.ModuleDict()
+
+ # Encoder for all observation groups.
+ self.nets["encoder"] = ObservationGroupEncoder(
+ observation_group_shapes=input_obs_group_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ # flat encoder output dimension
+ mlp_input_dim = self.nets["encoder"].output_shape()[0]
+
+ # intermediate MLP layers
+ self.nets["mlp"] = MLP(
+ input_dim=mlp_input_dim,
+ output_dim=layer_dims[-1],
+ layer_dims=layer_dims[:-1],
+ layer_func=layer_func,
+ activation=activation,
+ output_activation=activation, # make sure non-linearity is applied before decoder
+ )
+
+ # decoder for output modalities
+ self.nets["decoder"] = ObservationDecoder(
+ decode_shapes=self.output_shapes,
+ input_feat_dim=layer_dims[-1],
+ )
+
+ def output_shape(self, input_shape=None):
+ """
+ Returns output shape for this module, which is a dictionary instead
+ of a list since outputs are dictionaries.
+ """
+ return { k : list(self.output_shapes[k]) for k in self.output_shapes }
+
+ def forward(self, **inputs):
+ """
+ Process each set of inputs in its own observation group.
+
+ Args:
+ inputs (dict): a dictionary of dictionaries with one dictionary per
+ observation group. Each observation group's dictionary should map
+ modality to torch.Tensor batches. Should be consistent with
+ @self.input_obs_group_shapes.
+
+ Returns:
+ outputs (dict): dictionary of output torch.Tensors, that corresponds
+ to @self.output_shapes
+ """
+ enc_outputs = self.nets["encoder"](**inputs)
+ mlp_out = self.nets["mlp"](enc_outputs)
+ return self.nets["decoder"](mlp_out)
+
+ def _to_string(self):
+ """
+ Subclasses should override this method to print out info about network / policy.
+ """
+ return ''
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 4
+ if self._to_string() != '':
+ msg += textwrap.indent("\n" + self._to_string() + "\n", indent)
+ msg += textwrap.indent("\nencoder={}".format(self.nets["encoder"]), indent)
+ msg += textwrap.indent("\n\nmlp={}".format(self.nets["mlp"]), indent)
+ msg += textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+class RNN_MIMO_MLP(Module):
+ """
+ A wrapper class for a multi-step RNN and a per-step MLP and a decoder.
+
+ Structure: [encoder -> rnn -> mlp -> decoder]
+
+ All temporal inputs are processed by a shared @ObservationGroupEncoder,
+ followed by an RNN, and then a per-step multi-output MLP.
+ """
+ def __init__(
+ self,
+ input_obs_group_shapes,
+ output_shapes,
+ mlp_layer_dims,
+ rnn_hidden_dim,
+ rnn_num_layers,
+ rnn_type="LSTM", # [LSTM, GRU]
+ rnn_kwargs=None,
+ mlp_activation=nn.ReLU,
+ mlp_layer_func=nn.Linear,
+ per_step=True,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
+ Each key in this dictionary should specify an observation group, and
+ the value should be an OrderedDict that maps modalities to
+ expected shapes.
+
+ output_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for outputs.
+
+ rnn_hidden_dim (int): RNN hidden dimension
+
+ rnn_num_layers (int): number of RNN layers
+
+ rnn_type (str): [LSTM, GRU]
+
+ rnn_kwargs (dict): kwargs for the rnn model
+
+ per_step (bool): if True, apply the MLP and observation decoder into @output_shapes
+ at every step of the RNN. Otherwise, apply them to the final hidden state of the
+ RNN.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ super(RNN_MIMO_MLP, self).__init__()
+ assert isinstance(input_obs_group_shapes, OrderedDict)
+ assert np.all([isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes])
+ assert isinstance(output_shapes, OrderedDict)
+ self.input_obs_group_shapes = input_obs_group_shapes
+ self.output_shapes = output_shapes
+ self.per_step = per_step
+
+ self.nets = nn.ModuleDict()
+
+ # Encoder for all observation groups.
+ self.nets["encoder"] = ObservationGroupEncoder(
+ observation_group_shapes=input_obs_group_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ # flat encoder output dimension
+ rnn_input_dim = self.nets["encoder"].output_shape()[0]
+
+ # bidirectional RNNs mean that the output of RNN will be twice the hidden dimension
+ rnn_is_bidirectional = rnn_kwargs.get("bidirectional", False)
+ num_directions = int(rnn_is_bidirectional) + 1 # 2 if bidirectional, 1 otherwise
+ rnn_output_dim = num_directions * rnn_hidden_dim
+
+ per_step_net = None
+ self._has_mlp = (len(mlp_layer_dims) > 0)
+ if self._has_mlp:
+ self.nets["mlp"] = MLP(
+ input_dim=rnn_output_dim,
+ output_dim=mlp_layer_dims[-1],
+ layer_dims=mlp_layer_dims[:-1],
+ output_activation=mlp_activation,
+ layer_func=mlp_layer_func
+ )
+ self.nets["decoder"] = ObservationDecoder(
+ decode_shapes=self.output_shapes,
+ input_feat_dim=mlp_layer_dims[-1],
+ )
+ if self.per_step:
+ per_step_net = Sequential(self.nets["mlp"], self.nets["decoder"])
+ else:
+ self.nets["decoder"] = ObservationDecoder(
+ decode_shapes=self.output_shapes,
+ input_feat_dim=rnn_output_dim,
+ )
+ if self.per_step:
+ per_step_net = self.nets["decoder"]
+
+ # core network
+ self.nets["rnn"] = RNN_Base(
+ input_dim=rnn_input_dim,
+ rnn_hidden_dim=rnn_hidden_dim,
+ rnn_num_layers=rnn_num_layers,
+ rnn_type=rnn_type,
+ per_step_net=per_step_net,
+ rnn_kwargs=rnn_kwargs
+ )
+
+ def get_rnn_init_state(self, batch_size, device):
+ """
+ Get a default RNN state (zeros)
+
+ Args:
+ batch_size (int): batch size dimension
+
+ device: device the hidden state should be sent to.
+
+ Returns:
+ hidden_state (torch.Tensor or tuple): returns hidden state tensor or tuple of hidden state tensors
+ depending on the RNN type
+ """
+ return self.nets["rnn"].get_rnn_init_state(batch_size, device=device)
+
+ def output_shape(self, input_shape):
+ """
+ Returns output shape for this module, which is a dictionary instead
+ of a list since outputs are dictionaries.
+
+ Args:
+ input_shape (dict): dictionary of dictionaries, where each top-level key
+ corresponds to an observation group, and the low-level dictionaries
+ specify the shape for each modality in an observation dictionary
+ """
+
+ # infers temporal dimension from input shape
+ obs_group = list(self.input_obs_group_shapes.keys())[0]
+ mod = list(self.input_obs_group_shapes[obs_group].keys())[0]
+ T = input_shape[obs_group][mod][0]
+ TensorUtils.assert_size_at_dim(input_shape, size=T, dim=0,
+ msg="RNN_MIMO_MLP: input_shape inconsistent in temporal dimension")
+ # returns a dictionary instead of list since outputs are dictionaries
+ return { k : [T] + list(self.output_shapes[k]) for k in self.output_shapes }
+
+ def forward(self, rnn_init_state=None, return_state=False, **inputs):
+ """
+ Args:
+ inputs (dict): a dictionary of dictionaries with one dictionary per
+ observation group. Each observation group's dictionary should map
+ modality to torch.Tensor batches. Should be consistent with
+ @self.input_obs_group_shapes. First two leading dimensions should
+ be batch and time [B, T, ...] for each tensor.
+
+ rnn_init_state: rnn hidden state, initialize to zero state if set to None
+
+ return_state (bool): whether to return hidden state
+
+ Returns:
+ outputs (dict): dictionary of output torch.Tensors, that corresponds
+ to @self.output_shapes. Leading dimensions will be batch and time [B, T, ...]
+ for each tensor.
+
+ rnn_state (torch.Tensor or tuple): return the new rnn state (if @return_state)
+ """
+ for obs_group in self.input_obs_group_shapes:
+ for k in self.input_obs_group_shapes[obs_group]:
+ # first two dimensions should be [B, T] for inputs
+ assert inputs[obs_group][k].ndim - 2 == len(self.input_obs_group_shapes[obs_group][k])
+
+ # use encoder to extract flat rnn inputs
+ rnn_inputs = TensorUtils.time_distributed(inputs, self.nets["encoder"], inputs_as_kwargs=True)
+ assert rnn_inputs.ndim == 3 # [B, T, D]
+ if self.per_step:
+ return self.nets["rnn"].forward(inputs=rnn_inputs, rnn_init_state=rnn_init_state, return_state=return_state)
+
+ # apply MLP + decoder to last RNN output
+ outputs = self.nets["rnn"].forward(inputs=rnn_inputs, rnn_init_state=rnn_init_state, return_state=return_state)
+ if return_state:
+ outputs, rnn_state = outputs
+
+ assert outputs.ndim == 3 # [B, T, D]
+ if self._has_mlp:
+ outputs = self.nets["decoder"](self.nets["mlp"](outputs[:, -1]))
+ else:
+ outputs = self.nets["decoder"](outputs[:, -1])
+
+ if return_state:
+ return outputs, rnn_state
+ return outputs
+
+ def forward_step(self, rnn_state, **inputs):
+ """
+ Unroll network over a single timestep.
+
+ Args:
+ inputs (dict): expects same modalities as @self.input_shapes, with
+ additional batch dimension (but NOT time), since this is a
+ single time step.
+
+ rnn_state (torch.Tensor): rnn hidden state
+
+ Returns:
+ outputs (dict): dictionary of output torch.Tensors, that corresponds
+ to @self.output_shapes. Does not contain time dimension.
+
+ rnn_state: return the new rnn state
+ """
+ # ensure that the only extra dimension is batch dim, not temporal dim
+ assert np.all([inputs[k].ndim - 1 == len(self.input_shapes[k]) for k in self.input_shapes])
+
+ inputs = TensorUtils.to_sequence(inputs)
+ outputs, rnn_state = self.forward(
+ inputs,
+ rnn_init_state=rnn_state,
+ return_state=True,
+ )
+ if self.per_step:
+ # if outputs are not per-step, the time dimension is already reduced
+ outputs = outputs[:, 0]
+ return outputs, rnn_state
+
+ def _to_string(self):
+ """
+ Subclasses should override this method to print out info about network / policy.
+ """
+ return ''
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 4
+ msg += textwrap.indent("\n" + self._to_string(), indent)
+ msg += textwrap.indent("\n\nencoder={}".format(self.nets["encoder"]), indent)
+ msg += textwrap.indent("\n\nrnn={}".format(self.nets["rnn"]), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
+
+
+class MIMO_Transformer(Module):
+ """
+ Extension to Transformer (based on GPT architecture) to accept multiple observation
+ dictionaries as input and to output dictionaries of tensors. Inputs are specified as
+ a dictionary of observation dictionaries, with each key corresponding to an observation group.
+ This module utilizes @ObservationGroupEncoder to process the multiple input dictionaries and
+ @ObservationDecoder to generate tensor dictionaries. The default behavior
+ for encoding the inputs is to process visual inputs with a learned CNN and concatenating
+ the flat encodings with the other flat inputs. The default behavior for generating
+ outputs is to use a linear layer branch to produce each modality separately
+ (including visual outputs).
+ """
+ def __init__(
+ self,
+ input_obs_group_shapes,
+ output_shapes,
+ transformer_embed_dim,
+ transformer_num_layers,
+ transformer_num_heads,
+ transformer_context_length,
+ transformer_emb_dropout=0.1,
+ transformer_attn_dropout=0.1,
+ transformer_block_output_dropout=0.1,
+ transformer_sinusoidal_embedding=False,
+ transformer_activation="gelu",
+ transformer_nn_parameter_for_timesteps=False,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
+ Each key in this dictionary should specify an observation group, and
+ the value should be an OrderedDict that maps modalities to
+ expected shapes.
+ output_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for outputs.
+ transformer_embed_dim (int): dimension for embeddings used by transformer
+ transformer_num_layers (int): number of transformer blocks to stack
+ transformer_num_heads (int): number of attention heads for each
+ transformer block - must divide @transformer_embed_dim evenly. Self-attention is
+ computed over this many partitions of the embedding dimension separately.
+ transformer_context_length (int): expected length of input sequences
+ transformer_activation: non-linearity for input and output layers used in transformer
+ transformer_emb_dropout (float): dropout probability for embedding inputs in transformer
+ transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block
+ transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block
+ encoder_kwargs (dict): observation encoder config
+ """
+ super(MIMO_Transformer, self).__init__()
+
+ assert isinstance(input_obs_group_shapes, OrderedDict)
+ assert np.all([isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes])
+ assert isinstance(output_shapes, OrderedDict)
+
+ self.input_obs_group_shapes = input_obs_group_shapes
+ self.output_shapes = output_shapes
+
+ self.nets = nn.ModuleDict()
+ self.params = nn.ParameterDict()
+
+ # Encoder for all observation groups.
+ self.nets["encoder"] = ObservationGroupEncoder(
+ observation_group_shapes=input_obs_group_shapes,
+ encoder_kwargs=encoder_kwargs,
+ feature_activation=None,
+ )
+
+ # flat encoder output dimension
+ transformer_input_dim = self.nets["encoder"].output_shape()[0]
+
+ self.nets["embed_encoder"] = nn.Linear(
+ transformer_input_dim, transformer_embed_dim
+ )
+
+ max_timestep = transformer_context_length
+
+ if transformer_sinusoidal_embedding:
+ self.nets["embed_timestep"] = PositionalEncoding(transformer_embed_dim)
+ elif transformer_nn_parameter_for_timesteps:
+ assert (
+ not transformer_sinusoidal_embedding
+ ), "nn.Parameter only works with learned embeddings"
+ self.params["embed_timestep"] = nn.Parameter(
+ torch.zeros(1, max_timestep, transformer_embed_dim)
+ )
+ else:
+ self.nets["embed_timestep"] = nn.Embedding(max_timestep, transformer_embed_dim)
+
+ # layer norm for embeddings
+ self.nets["embed_ln"] = nn.LayerNorm(transformer_embed_dim)
+
+ # dropout for input embeddings
+ self.nets["embed_drop"] = nn.Dropout(transformer_emb_dropout)
+
+ # GPT transformer
+ self.nets["transformer"] = GPT_Backbone(
+ embed_dim=transformer_embed_dim,
+ num_layers=transformer_num_layers,
+ num_heads=transformer_num_heads,
+ context_length=transformer_context_length,
+ attn_dropout=transformer_attn_dropout,
+ block_output_dropout=transformer_block_output_dropout,
+ activation=transformer_activation,
+ )
+
+ # decoder for output modalities
+ self.nets["decoder"] = ObservationDecoder(
+ decode_shapes=self.output_shapes,
+ input_feat_dim=transformer_embed_dim,
+ )
+
+ self.transformer_context_length = transformer_context_length
+ self.transformer_embed_dim = transformer_embed_dim
+ self.transformer_sinusoidal_embedding = transformer_sinusoidal_embedding
+ self.transformer_nn_parameter_for_timesteps = transformer_nn_parameter_for_timesteps
+
+ def output_shape(self, input_shape=None):
+ """
+ Returns output shape for this module, which is a dictionary instead
+ of a list since outputs are dictionaries.
+ """
+ return { k : list(self.output_shapes[k]) for k in self.output_shapes }
+
+ def embed_timesteps(self, embeddings):
+ """
+ Computes timestep-based embeddings (aka positional embeddings) to add to embeddings.
+ Args:
+ embeddings (torch.Tensor): embeddings prior to positional embeddings are computed
+ Returns:
+ time_embeddings (torch.Tensor): positional embeddings to add to embeddings
+ """
+ timesteps = (
+ torch.arange(
+ 0,
+ embeddings.shape[1],
+ dtype=embeddings.dtype,
+ device=embeddings.device,
+ )
+ .unsqueeze(0)
+ .repeat(embeddings.shape[0], 1)
+ )
+ assert (timesteps >= 0.0).all(), "timesteps must be positive!"
+ if self.transformer_sinusoidal_embedding:
+ assert torch.is_floating_point(timesteps), timesteps.dtype
+ else:
+ timesteps = timesteps.long()
+
+ if self.transformer_nn_parameter_for_timesteps:
+ time_embeddings = self.params["embed_timestep"]
+ else:
+ time_embeddings = self.nets["embed_timestep"](
+ timesteps
+ ) # these are NOT fed into transformer, only added to the inputs.
+ # compute how many modalities were combined into embeddings, replicate time embeddings that many times
+ num_replicates = embeddings.shape[-1] // self.transformer_embed_dim
+ time_embeddings = torch.cat([time_embeddings for _ in range(num_replicates)], -1)
+ assert (
+ embeddings.shape == time_embeddings.shape
+ ), f"{embeddings.shape}, {time_embeddings.shape}"
+ return time_embeddings
+
+ def input_embedding(
+ self,
+ inputs,
+ ):
+ """
+ Process encoded observations into embeddings to pass to transformer,
+ Adds timestep-based embeddings (aka positional embeddings) to inputs.
+ Args:
+ inputs (torch.Tensor): outputs from observation encoder
+ Returns:
+ embeddings (torch.Tensor): input embeddings to pass to transformer backbone.
+ """
+ embeddings = self.nets["embed_encoder"](inputs)
+ time_embeddings = self.embed_timesteps(embeddings)
+ embeddings = embeddings + time_embeddings
+ embeddings = self.nets["embed_ln"](embeddings)
+ embeddings = self.nets["embed_drop"](embeddings)
+
+ return embeddings
+
+
+ def forward(self, **inputs):
+ """
+ Process each set of inputs in its own observation group.
+ Args:
+ inputs (dict): a dictionary of dictionaries with one dictionary per
+ observation group. Each observation group's dictionary should map
+ modality to torch.Tensor batches. Should be consistent with
+ @self.input_obs_group_shapes. First two leading dimensions should
+ be batch and time [B, T, ...] for each tensor.
+ Returns:
+ outputs (dict): dictionary of output torch.Tensors, that corresponds
+ to @self.output_shapes. Leading dimensions will be batch and time [B, T, ...]
+ for each tensor.
+ """
+ for obs_group in self.input_obs_group_shapes:
+ for k in self.input_obs_group_shapes[obs_group]:
+ # first two dimensions should be [B, T] for inputs
+ if inputs[obs_group][k] is None:
+ continue
+ assert inputs[obs_group][k].ndim - 2 == len(self.input_obs_group_shapes[obs_group][k])
+
+ inputs = inputs.copy()
+
+ transformer_encoder_outputs = None
+ transformer_inputs = TensorUtils.time_distributed(
+ inputs, self.nets["encoder"], inputs_as_kwargs=True
+ )
+ assert transformer_inputs.ndim == 3 # [B, T, D]
+
+ if transformer_encoder_outputs is None:
+ transformer_embeddings = self.input_embedding(transformer_inputs)
+ # pass encoded sequences through transformer
+ transformer_encoder_outputs = self.nets["transformer"].forward(transformer_embeddings)
+
+ transformer_outputs = transformer_encoder_outputs
+ # apply decoder to each timestep of sequence to get a dictionary of outputs
+ transformer_outputs = TensorUtils.time_distributed(
+ transformer_outputs, self.nets["decoder"]
+ )
+ transformer_outputs["transformer_encoder_outputs"] = transformer_encoder_outputs
+ return transformer_outputs
+
+ def _to_string(self):
+ """
+ Subclasses should override this method to print out info about network / policy.
+ """
+ return ''
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 4
+ if self._to_string() != '':
+ msg += textwrap.indent("\n" + self._to_string() + "\n", indent)
+ msg += textwrap.indent("\nencoder={}".format(self.nets["encoder"]), indent)
+ msg += textwrap.indent("\n\ntransformer={}".format(self.nets["transformer"]), indent)
+ msg += textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), indent)
+ msg = header + '(' + msg + '\n)'
+ return msg
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/policy_nets.py b/phantom/submodules/phantom-robomimic/robomimic/models/policy_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dba1d934cbb6b6a6f2d5c6475d699c48eb2a302
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/policy_nets.py
@@ -0,0 +1,1570 @@
+"""
+Contains torch Modules for policy networks. These networks take an
+observation dictionary as input (and possibly additional conditioning,
+such as subgoal or goal dictionaries) and produce action predictions,
+samples, or distributions as outputs. Note that actions
+are assumed to lie in [-1, 1], and most networks will have a final
+tanh activation to help ensure this range.
+"""
+import textwrap
+import numpy as np
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as D
+
+import robomimic.utils.tensor_utils as TensorUtils
+from robomimic.models.base_nets import Module
+from robomimic.models.transformers import GPT_Backbone
+from robomimic.models.obs_nets import MIMO_MLP, RNN_MIMO_MLP, MIMO_Transformer, ObservationDecoder
+from robomimic.models.vae_nets import VAE
+from robomimic.models.distributions import TanhWrappedDistribution
+
+
+class ActorNetwork(MIMO_MLP):
+ """
+ A basic policy network that predicts actions from observations.
+ Can optionally be goal conditioned on future observations.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ goal_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-observation key information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ assert isinstance(obs_shapes, OrderedDict)
+ self.obs_shapes = obs_shapes
+ self.ac_dim = ac_dim
+
+ # set up different observation groups for @MIMO_MLP
+ observation_group_shapes = OrderedDict()
+ observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
+
+ self._is_goal_conditioned = False
+ if goal_shapes is not None and len(goal_shapes) > 0:
+ assert isinstance(goal_shapes, OrderedDict)
+ self._is_goal_conditioned = True
+ self.goal_shapes = OrderedDict(goal_shapes)
+ observation_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+ else:
+ self.goal_shapes = OrderedDict()
+
+ output_shapes = self._get_output_shapes()
+ super(ActorNetwork, self).__init__(
+ input_obs_group_shapes=observation_group_shapes,
+ output_shapes=output_shapes,
+ layer_dims=mlp_layer_dims,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Allow subclasses to re-define outputs from @MIMO_MLP, since we won't
+ always directly predict actions, but may instead predict the parameters
+ of a action distribution.
+ """
+ return OrderedDict(action=(self.ac_dim,))
+
+ def output_shape(self, input_shape=None):
+ return [self.ac_dim]
+
+ def forward(self, obs_dict, goal_dict=None):
+ actions = super(ActorNetwork, self).forward(obs=obs_dict, goal=goal_dict)["action"]
+ # apply tanh squashing to ensure actions are in [-1, 1]
+ return torch.tanh(actions)
+
+ def _to_string(self):
+ """Info to pretty print."""
+ return "action_dim={}".format(self.ac_dim)
+
+
+class PerturbationActorNetwork(ActorNetwork):
+ """
+ An action perturbation network - primarily used in BCQ.
+ It takes states and actions and returns action perturbations.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ perturbation_scale=0.05,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ perturbation_scale (float): the perturbation network output is always squashed to
+ lie in +/- @perturbation_scale. The final action output is equal to the original
+ input action added to the output perturbation (and clipped to lie in [-1, 1]).
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ self.perturbation_scale = perturbation_scale
+
+ # add in action as a modality
+ new_obs_shapes = OrderedDict(obs_shapes)
+ new_obs_shapes["action"] = (ac_dim,)
+
+ # pass to super class to instantiate network
+ super(PerturbationActorNetwork, self).__init__(
+ obs_shapes=new_obs_shapes,
+ ac_dim=ac_dim,
+ mlp_layer_dims=mlp_layer_dims,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def forward(self, obs_dict, acts, goal_dict=None):
+ """Forward pass through perturbation actor."""
+ # add in actions
+ inputs = dict(obs_dict)
+ inputs["action"] = acts
+ perturbations = super(PerturbationActorNetwork, self).forward(inputs, goal_dict)
+
+ # add perturbations from network to original actions, and ensure the new actions lie in [-1, 1]
+ output_actions = acts + self.perturbation_scale * perturbations
+ output_actions = output_actions.clamp(-1.0, 1.0)
+ return output_actions
+
+ def _to_string(self):
+ """Info to pretty print."""
+ return "action_dim={}, perturbation_scale={}".format(self.ac_dim, self.perturbation_scale)
+
+
+class GaussianActorNetwork(ActorNetwork):
+ """
+ Variant of actor network that learns a diagonal unimodal Gaussian distribution
+ over actions.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ fixed_std=False,
+ std_activation="softplus",
+ init_last_fc_weight=None,
+ init_std=0.3,
+ mean_limits=(-9.0, 9.0),
+ std_limits=(0.007, 7.5),
+ low_noise_eval=True,
+ use_tanh=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ fixed_std (bool): if True, std is not learned, but kept constant at @init_std
+
+ std_activation (None or str): type of activation to use for std deviation. Options are:
+
+ None: no activation applied (not recommended unless using fixed std)
+
+ `'softplus'`: Only applicable if not using fixed std. Softplus activation applied, after which the
+ output is scaled by init_std / softplus(0)
+
+ `'exp'`: Only applicable if not using fixed std. Exp applied; this corresponds to network output
+ as being interpreted as log_std instead of std
+
+ NOTE: In all cases, the final result is clipped to be within @std_limits
+
+ init_last_fc_weight (None or float): if specified, will intialize the final layer network weights to be
+ uniformly sampled from [-init_weight, init_weight]
+
+ init_std (None or float): approximate initial scaling for standard deviation outputs
+ from network. If None
+
+ mean_limits (2-array): (min, max) to clamp final mean output by
+
+ std_limits (2-array): (min, max) to clamp final std output by
+
+ low_noise_eval (float): if True, model will output means of Gaussian distribution
+ at eval time.
+
+ use_tanh (bool): if True, use a tanh-Gaussian distribution
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+
+ # parameters specific to Gaussian actor
+ self.fixed_std = fixed_std
+ self.init_std = init_std
+ self.mean_limits = np.array(mean_limits)
+ self.std_limits = np.array(std_limits)
+
+ # Define activations to use
+ def softplus_scaled(x):
+ out = F.softplus(x)
+ out = out * (self.init_std / F.softplus(torch.zeros(1).to(x.device)))
+ return out
+
+ self.activations = {
+ None: lambda x: x,
+ "softplus": softplus_scaled,
+ "exp": torch.exp,
+ }
+ assert std_activation in self.activations, \
+ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation)
+ self.std_activation = std_activation if not self.fixed_std else None
+
+ self.low_noise_eval = low_noise_eval
+ self.use_tanh = use_tanh
+
+ super(GaussianActorNetwork, self).__init__(
+ obs_shapes=obs_shapes,
+ ac_dim=ac_dim,
+ mlp_layer_dims=mlp_layer_dims,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ # If initialization weight was specified, make sure all final layer network weights are specified correctly
+ if init_last_fc_weight is not None:
+ with torch.no_grad():
+ for name, layer in self.nets["decoder"].nets.items():
+ torch.nn.init.uniform_(layer.weight, -init_last_fc_weight, init_last_fc_weight)
+ torch.nn.init.uniform_(layer.bias, -init_last_fc_weight, init_last_fc_weight)
+
+ def _get_output_shapes(self):
+ """
+ Tells @MIMO_MLP superclass about the output dictionary that should be generated
+ at the last layer. Network outputs parameters of Gaussian distribution.
+ """
+ return OrderedDict(
+ mean=(self.ac_dim,),
+ scale=(self.ac_dim,),
+ )
+
+ def forward_train(self, obs_dict, goal_dict=None):
+ """
+ Return full Gaussian distribution, which is useful for computing
+ quantities necessary at train-time, like log-likelihood, KL
+ divergence, etc.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ dist (Distribution): Gaussian distribution
+ """
+ out = MIMO_MLP.forward(self, obs=obs_dict, goal=goal_dict)
+ mean = out["mean"]
+ # Use either constant std or learned std depending on setting
+ scale = out["scale"] if not self.fixed_std else torch.ones_like(mean) * self.init_std
+
+ # Clamp the mean
+ mean = torch.clamp(mean, min=self.mean_limits[0], max=self.mean_limits[1])
+
+ # apply tanh squashing to mean if not using tanh-Gaussian to ensure mean is in [-1, 1]
+ if not self.use_tanh:
+ mean = torch.tanh(mean)
+
+ # Calculate scale
+ if self.low_noise_eval and (not self.training):
+ # override std value so that you always approximately sample the mean
+ scale = torch.ones_like(mean) * 1e-4
+ else:
+ # Post-process the scale accordingly
+ scale = self.activations[self.std_activation](scale)
+ # Clamp the scale
+ scale = torch.clamp(scale, min=self.std_limits[0], max=self.std_limits[1])
+
+
+ # the Independent call will make it so that `batch_shape` for dist will be equal to batch size
+ # while `event_shape` will be equal to action dimension - ensuring that log-probability
+ # computations are summed across the action dimension
+ dist = D.Normal(loc=mean, scale=scale)
+ dist = D.Independent(dist, 1)
+
+ if self.use_tanh:
+ # Wrap distribution with Tanh
+ dist = TanhWrappedDistribution(base_dist=dist, scale=1.)
+
+ return dist
+
+ def forward(self, obs_dict, goal_dict=None):
+ """
+ Samples actions from the policy distribution.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ action (torch.Tensor): batch of actions from policy distribution
+ """
+ dist = self.forward_train(obs_dict, goal_dict)
+ if self.low_noise_eval and (not self.training):
+ if self.use_tanh:
+ # # scaling factor lets us output actions like [-1. 1.] and is consistent with the distribution transform
+ # return (1. + 1e-6) * torch.tanh(dist.base_dist.mean)
+ return torch.tanh(dist.mean)
+ return dist.mean
+ return dist.sample()
+
+ def _to_string(self):
+ """Info to pretty print."""
+ msg = "action_dim={}\nfixed_std={}\nstd_activation={}\ninit_std={}\nmean_limits={}\nstd_limits={}\nlow_noise_eval={}".format(
+ self.ac_dim, self.fixed_std, self.std_activation, self.init_std, self.mean_limits, self.std_limits, self.low_noise_eval)
+ return msg
+
+
+class GMMActorNetwork(ActorNetwork):
+ """
+ Variant of actor network that learns a multimodal Gaussian mixture distribution
+ over actions.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ num_modes=5,
+ min_std=0.01,
+ std_activation="softplus",
+ low_noise_eval=True,
+ use_tanh=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ num_modes (int): number of GMM modes
+
+ min_std (float): minimum std output from network
+
+ std_activation (None or str): type of activation to use for std deviation. Options are:
+
+ `'softplus'`: Softplus activation applied
+
+ `'exp'`: Exp applied; this corresponds to network output being interpreted as log_std instead of std
+
+ low_noise_eval (float): if True, model will sample from GMM with low std, so that
+ one of the GMM modes will be sampled (approximately)
+
+ use_tanh (bool): if True, use a tanh-Gaussian distribution
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+
+ # parameters specific to GMM actor
+ self.num_modes = num_modes
+ self.min_std = min_std
+ self.low_noise_eval = low_noise_eval
+ self.use_tanh = use_tanh
+
+ # Define activations to use
+ self.activations = {
+ "softplus": F.softplus,
+ "exp": torch.exp,
+ }
+ assert std_activation in self.activations, \
+ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation)
+ self.std_activation = std_activation
+
+ super(GMMActorNetwork, self).__init__(
+ obs_shapes=obs_shapes,
+ ac_dim=ac_dim,
+ mlp_layer_dims=mlp_layer_dims,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Tells @MIMO_MLP superclass about the output dictionary that should be generated
+ at the last layer. Network outputs parameters of GMM distribution.
+ """
+ return OrderedDict(
+ mean=(self.num_modes, self.ac_dim),
+ scale=(self.num_modes, self.ac_dim),
+ logits=(self.num_modes,),
+ )
+
+ def forward_train(self, obs_dict, goal_dict=None):
+ """
+ Return full GMM distribution, which is useful for computing
+ quantities necessary at train-time, like log-likelihood, KL
+ divergence, etc.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ dist (Distribution): GMM distribution
+ """
+ out = MIMO_MLP.forward(self, obs=obs_dict, goal=goal_dict)
+ means = out["mean"]
+ scales = out["scale"]
+ logits = out["logits"]
+
+ # apply tanh squashing to means if not using tanh-GMM to ensure means are in [-1, 1]
+ if not self.use_tanh:
+ means = torch.tanh(means)
+
+ # Calculate scale
+ if self.low_noise_eval and (not self.training):
+ # low-noise for all Gaussian dists
+ scales = torch.ones_like(means) * 1e-4
+ else:
+ # post-process the scale accordingly
+ scales = self.activations[self.std_activation](scales) + self.min_std
+
+ # mixture components - make sure that `batch_shape` for the distribution is equal
+ # to (batch_size, num_modes) since MixtureSameFamily expects this shape
+ component_distribution = D.Normal(loc=means, scale=scales)
+ component_distribution = D.Independent(component_distribution, 1)
+
+ # unnormalized logits to categorical distribution for mixing the modes
+ mixture_distribution = D.Categorical(logits=logits)
+
+ dist = D.MixtureSameFamily(
+ mixture_distribution=mixture_distribution,
+ component_distribution=component_distribution,
+ )
+
+ if self.use_tanh:
+ # Wrap distribution with Tanh
+ dist = TanhWrappedDistribution(base_dist=dist, scale=1.)
+
+ return dist
+
+ def forward(self, obs_dict, goal_dict=None):
+ """
+ Samples actions from the policy distribution.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ action (torch.Tensor): batch of actions from policy distribution
+ """
+ dist = self.forward_train(obs_dict, goal_dict)
+ return dist.sample()
+
+ def _to_string(self):
+ """Info to pretty print."""
+ return "action_dim={}\nnum_modes={}\nmin_std={}\nstd_activation={}\nlow_noise_eval={}".format(
+ self.ac_dim, self.num_modes, self.min_std, self.std_activation, self.low_noise_eval)
+
+
+class RNNActorNetwork(RNN_MIMO_MLP):
+ """
+ An RNN policy network that predicts actions from observations.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ rnn_hidden_dim,
+ rnn_num_layers,
+ rnn_type="LSTM", # [LSTM, GRU]
+ rnn_kwargs=None,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ rnn_hidden_dim (int): RNN hidden dimension
+
+ rnn_num_layers (int): number of RNN layers
+
+ rnn_type (str): [LSTM, GRU]
+
+ rnn_kwargs (dict): kwargs for the torch.nn.LSTM / GRU
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ self.ac_dim = ac_dim
+
+ assert isinstance(obs_shapes, OrderedDict)
+ self.obs_shapes = obs_shapes
+
+ # set up different observation groups for @RNN_MIMO_MLP
+ observation_group_shapes = OrderedDict()
+ observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
+
+ self._is_goal_conditioned = False
+ if goal_shapes is not None and len(goal_shapes) > 0:
+ assert isinstance(goal_shapes, OrderedDict)
+ self._is_goal_conditioned = True
+ self.goal_shapes = OrderedDict(goal_shapes)
+ observation_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+ else:
+ self.goal_shapes = OrderedDict()
+
+ output_shapes = self._get_output_shapes()
+ super(RNNActorNetwork, self).__init__(
+ input_obs_group_shapes=observation_group_shapes,
+ output_shapes=output_shapes,
+ mlp_layer_dims=mlp_layer_dims,
+ mlp_activation=nn.ReLU,
+ mlp_layer_func=nn.Linear,
+ rnn_hidden_dim=rnn_hidden_dim,
+ rnn_num_layers=rnn_num_layers,
+ rnn_type=rnn_type,
+ rnn_kwargs=rnn_kwargs,
+ per_step=True,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Allow subclasses to re-define outputs from @RNN_MIMO_MLP, since we won't
+ always directly predict actions, but may instead predict the parameters
+ of a action distribution.
+ """
+ return OrderedDict(action=(self.ac_dim,))
+
+ def output_shape(self, input_shape):
+ # note: @input_shape should be dictionary (key: mod)
+ # infers temporal dimension from input shape
+ mod = list(self.obs_shapes.keys())[0]
+ T = input_shape[mod][0]
+ TensorUtils.assert_size_at_dim(input_shape, size=T, dim=0,
+ msg="RNNActorNetwork: input_shape inconsistent in temporal dimension")
+ return [T, self.ac_dim]
+
+ def forward(self, obs_dict, goal_dict=None, rnn_init_state=None, return_state=False):
+ """
+ Forward a sequence of inputs through the RNN and the per-step network.
+
+ Args:
+ obs_dict (dict): batch of observations - each tensor in the dictionary
+ should have leading dimensions batch and time [B, T, ...]
+ goal_dict (dict): if not None, batch of goal observations
+ rnn_init_state: rnn hidden state, initialize to zero state if set to None
+ return_state (bool): whether to return hidden state
+
+ Returns:
+ actions (torch.Tensor): predicted action sequence
+ rnn_state: return rnn state at the end if return_state is set to True
+ """
+ if self._is_goal_conditioned:
+ assert goal_dict is not None
+ # repeat the goal observation in time to match dimension with obs_dict
+ mod = list(obs_dict.keys())[0]
+ goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1)
+
+ outputs = super(RNNActorNetwork, self).forward(
+ obs=obs_dict, goal=goal_dict, rnn_init_state=rnn_init_state, return_state=return_state)
+
+ if return_state:
+ actions, state = outputs
+ else:
+ actions = outputs
+ state = None
+
+ # apply tanh squashing to ensure actions are in [-1, 1]
+ actions = torch.tanh(actions["action"])
+
+ if return_state:
+ return actions, state
+ else:
+ return actions
+
+ def forward_step(self, obs_dict, goal_dict=None, rnn_state=None):
+ """
+ Unroll RNN over single timestep to get actions.
+
+ Args:
+ obs_dict (dict): batch of observations. Should not contain
+ time dimension.
+ goal_dict (dict): if not None, batch of goal observations
+ rnn_state: rnn hidden state, initialize to zero state if set to None
+
+ Returns:
+ actions (torch.Tensor): batch of actions - does not contain time dimension
+ state: updated rnn state
+ """
+ obs_dict = TensorUtils.to_sequence(obs_dict)
+ action, state = self.forward(
+ obs_dict, goal_dict, rnn_init_state=rnn_state, return_state=True)
+ return action[:, 0], state
+
+ def _to_string(self):
+ """Info to pretty print."""
+ return "action_dim={}".format(self.ac_dim)
+
+
+class RNNGMMActorNetwork(RNNActorNetwork):
+ """
+ An RNN GMM policy network that predicts sequences of action distributions from observation sequences.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ rnn_hidden_dim,
+ rnn_num_layers,
+ rnn_type="LSTM", # [LSTM, GRU]
+ rnn_kwargs=None,
+ num_modes=5,
+ min_std=0.01,
+ std_activation="softplus",
+ low_noise_eval=True,
+ use_tanh=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+
+ rnn_hidden_dim (int): RNN hidden dimension
+
+ rnn_num_layers (int): number of RNN layers
+
+ rnn_type (str): [LSTM, GRU]
+
+ rnn_kwargs (dict): kwargs for the torch.nn.LSTM / GRU
+
+ num_modes (int): number of GMM modes
+
+ min_std (float): minimum std output from network
+
+ std_activation (None or str): type of activation to use for std deviation. Options are:
+
+ `'softplus'`: Softplus activation applied
+
+ `'exp'`: Exp applied; this corresponds to network output being interpreted as log_std instead of std
+
+ low_noise_eval (float): if True, model will sample from GMM with low std, so that
+ one of the GMM modes will be sampled (approximately)
+
+ use_tanh (bool): if True, use a tanh-Gaussian distribution
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+
+ # parameters specific to GMM actor
+ self.num_modes = num_modes
+ self.min_std = min_std
+ self.low_noise_eval = low_noise_eval
+ self.use_tanh = use_tanh
+
+ # Define activations to use
+ self.activations = {
+ "softplus": F.softplus,
+ "exp": torch.exp,
+ }
+ assert std_activation in self.activations, \
+ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation)
+ self.std_activation = std_activation
+
+ super(RNNGMMActorNetwork, self).__init__(
+ obs_shapes=obs_shapes,
+ ac_dim=ac_dim,
+ mlp_layer_dims=mlp_layer_dims,
+ rnn_hidden_dim=rnn_hidden_dim,
+ rnn_num_layers=rnn_num_layers,
+ rnn_type=rnn_type,
+ rnn_kwargs=rnn_kwargs,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Tells @MIMO_MLP superclass about the output dictionary that should be generated
+ at the last layer. Network outputs parameters of GMM distribution.
+ """
+ return OrderedDict(
+ mean=(self.num_modes, self.ac_dim),
+ scale=(self.num_modes, self.ac_dim),
+ logits=(self.num_modes,),
+ )
+
+ def forward_train(self, obs_dict, goal_dict=None, rnn_init_state=None, return_state=False):
+ """
+ Return full GMM distribution, which is useful for computing
+ quantities necessary at train-time, like log-likelihood, KL
+ divergence, etc.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+ rnn_init_state: rnn hidden state, initialize to zero state if set to None
+ return_state (bool): whether to return hidden state
+
+ Returns:
+ dists (Distribution): sequence of GMM distributions over the timesteps
+ rnn_state: return rnn state at the end if return_state is set to True
+ """
+ if self._is_goal_conditioned:
+ assert goal_dict is not None
+ # repeat the goal observation in time to match dimension with obs_dict
+ mod = list(obs_dict.keys())[0]
+ goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1)
+
+ outputs = RNN_MIMO_MLP.forward(
+ self, obs=obs_dict, goal=goal_dict, rnn_init_state=rnn_init_state, return_state=return_state)
+
+ if return_state:
+ outputs, state = outputs
+ else:
+ state = None
+
+ means = outputs["mean"]
+ scales = outputs["scale"]
+ logits = outputs["logits"]
+
+ # apply tanh squashing to mean if not using tanh-GMM to ensure means are in [-1, 1]
+ if not self.use_tanh:
+ means = torch.tanh(means)
+
+ if self.low_noise_eval and (not self.training):
+ # low-noise for all Gaussian dists
+ scales = torch.ones_like(means) * 1e-4
+ else:
+ # post-process the scale accordingly
+ scales = self.activations[self.std_activation](scales) + self.min_std
+
+ # mixture components - make sure that `batch_shape` for the distribution is equal
+ # to (batch_size, timesteps, num_modes) since MixtureSameFamily expects this shape
+ component_distribution = D.Normal(loc=means, scale=scales)
+ component_distribution = D.Independent(component_distribution, 1) # shift action dim to event shape
+
+ # unnormalized logits to categorical distribution for mixing the modes
+ mixture_distribution = D.Categorical(logits=logits)
+
+ dists = D.MixtureSameFamily(
+ mixture_distribution=mixture_distribution,
+ component_distribution=component_distribution,
+ )
+
+ if self.use_tanh:
+ # Wrap distribution with Tanh
+ dists = TanhWrappedDistribution(base_dist=dists, scale=1.)
+
+ if return_state:
+ return dists, state
+ else:
+ return dists
+
+ def forward(self, obs_dict, goal_dict=None, rnn_init_state=None, return_state=False):
+ """
+ Samples actions from the policy distribution.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ action (torch.Tensor): batch of actions from policy distribution
+ """
+ out = self.forward_train(obs_dict=obs_dict, goal_dict=goal_dict, rnn_init_state=rnn_init_state, return_state=return_state)
+ if return_state:
+ ad, state = out
+ return ad.sample(), state
+ return out.sample()
+
+ def forward_train_step(self, obs_dict, goal_dict=None, rnn_state=None):
+ """
+ Unroll RNN over single timestep to get action GMM distribution, which
+ is useful for computing quantities necessary at train-time, like
+ log-likelihood, KL divergence, etc.
+
+ Args:
+ obs_dict (dict): batch of observations. Should not contain
+ time dimension.
+ goal_dict (dict): if not None, batch of goal observations
+ rnn_state: rnn hidden state, initialize to zero state if set to None
+
+ Returns:
+ ad (Distribution): GMM action distributions
+ state: updated rnn state
+ """
+ obs_dict = TensorUtils.to_sequence(obs_dict)
+ ad, state = self.forward_train(
+ obs_dict, goal_dict, rnn_init_state=rnn_state, return_state=True)
+
+ # to squeeze time dimension, make another action distribution
+ assert ad.component_distribution.base_dist.loc.shape[1] == 1
+ assert ad.component_distribution.base_dist.scale.shape[1] == 1
+ assert ad.mixture_distribution.logits.shape[1] == 1
+ component_distribution = D.Normal(
+ loc=ad.component_distribution.base_dist.loc.squeeze(1),
+ scale=ad.component_distribution.base_dist.scale.squeeze(1),
+ )
+ component_distribution = D.Independent(component_distribution, 1)
+ mixture_distribution = D.Categorical(logits=ad.mixture_distribution.logits.squeeze(1))
+ ad = D.MixtureSameFamily(
+ mixture_distribution=mixture_distribution,
+ component_distribution=component_distribution,
+ )
+ return ad, state
+
+ def forward_step(self, obs_dict, goal_dict=None, rnn_state=None):
+ """
+ Unroll RNN over single timestep to get sampled actions.
+
+ Args:
+ obs_dict (dict): batch of observations. Should not contain
+ time dimension.
+ goal_dict (dict): if not None, batch of goal observations
+ rnn_state: rnn hidden state, initialize to zero state if set to None
+
+ Returns:
+ acts (torch.Tensor): batch of actions - does not contain time dimension
+ state: updated rnn state
+ """
+ obs_dict = TensorUtils.to_sequence(obs_dict)
+ acts, state = self.forward(
+ obs_dict, goal_dict, rnn_init_state=rnn_state, return_state=True)
+ assert acts.shape[1] == 1
+ return acts[:, 0], state
+
+ def _to_string(self):
+ """Info to pretty print."""
+ msg = "action_dim={}, std_activation={}, low_noise_eval={}, num_nodes={}, min_std={}".format(
+ self.ac_dim, self.std_activation, self.low_noise_eval, self.num_modes, self.min_std)
+ return msg
+
+
+class TransformerActorNetwork(MIMO_Transformer):
+ """
+ An Transformer policy network that predicts actions from observation sequences (assumed to be frame stacked
+ from previous observations) and possible from previous actions as well (in an autoregressive manner).
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ transformer_embed_dim,
+ transformer_num_layers,
+ transformer_num_heads,
+ transformer_context_length,
+ transformer_emb_dropout=0.1,
+ transformer_attn_dropout=0.1,
+ transformer_block_output_dropout=0.1,
+ transformer_sinusoidal_embedding=False,
+ transformer_activation="gelu",
+ transformer_nn_parameter_for_timesteps=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ transformer_embed_dim (int): dimension for embeddings used by transformer
+
+ transformer_num_layers (int): number of transformer blocks to stack
+
+ transformer_num_heads (int): number of attention heads for each
+ transformer block - must divide @transformer_embed_dim evenly. Self-attention is
+ computed over this many partitions of the embedding dimension separately.
+
+ transformer_context_length (int): expected length of input sequences
+
+ transformer_embedding_dropout (float): dropout probability for embedding inputs in transformer
+
+ transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block
+
+ transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ self.ac_dim = ac_dim
+
+ assert isinstance(obs_shapes, OrderedDict)
+ self.obs_shapes = obs_shapes
+
+ self.transformer_nn_parameter_for_timesteps = transformer_nn_parameter_for_timesteps
+
+ # set up different observation groups for @RNN_MIMO_MLP
+ observation_group_shapes = OrderedDict()
+ observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
+
+ self._is_goal_conditioned = False
+ if goal_shapes is not None and len(goal_shapes) > 0:
+ assert isinstance(goal_shapes, OrderedDict)
+ self._is_goal_conditioned = True
+ self.goal_shapes = OrderedDict(goal_shapes)
+ observation_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+ else:
+ self.goal_shapes = OrderedDict()
+
+ output_shapes = self._get_output_shapes()
+ super(TransformerActorNetwork, self).__init__(
+ input_obs_group_shapes=observation_group_shapes,
+ output_shapes=output_shapes,
+ transformer_embed_dim=transformer_embed_dim,
+ transformer_num_layers=transformer_num_layers,
+ transformer_num_heads=transformer_num_heads,
+ transformer_context_length=transformer_context_length,
+ transformer_emb_dropout=transformer_emb_dropout,
+ transformer_attn_dropout=transformer_attn_dropout,
+ transformer_block_output_dropout=transformer_block_output_dropout,
+ transformer_sinusoidal_embedding=transformer_sinusoidal_embedding,
+ transformer_activation=transformer_activation,
+ transformer_nn_parameter_for_timesteps=transformer_nn_parameter_for_timesteps,
+
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Allow subclasses to re-define outputs from @MIMO_Transformer, since we won't
+ always directly predict actions, but may instead predict the parameters
+ of a action distribution.
+ """
+ output_shapes = OrderedDict(action=(self.ac_dim,))
+ return output_shapes
+
+ def output_shape(self, input_shape):
+ # note: @input_shape should be dictionary (key: mod)
+ # infers temporal dimension from input shape
+ mod = list(self.obs_shapes.keys())[0]
+ T = input_shape[mod][0]
+ TensorUtils.assert_size_at_dim(input_shape, size=T, dim=0,
+ msg="TransformerActorNetwork: input_shape inconsistent in temporal dimension")
+ return [T, self.ac_dim]
+
+ def forward(self, obs_dict, actions=None, goal_dict=None):
+ """
+ Forward a sequence of inputs through the Transformer.
+ Args:
+ obs_dict (dict): batch of observations - each tensor in the dictionary
+ should have leading dimensions batch and time [B, T, ...]
+ actions (torch.Tensor): batch of actions of shape [B, T, D]
+ goal_dict (dict): if not None, batch of goal observations
+ Returns:
+ outputs (torch.Tensor or dict): contains predicted action sequence, or dictionary
+ with predicted action sequence and predicted observation sequences
+ """
+ if self._is_goal_conditioned:
+ assert goal_dict is not None
+ # repeat the goal observation in time to match dimension with obs_dict
+ mod = list(obs_dict.keys())[0]
+ goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1)
+
+ forward_kwargs = dict(obs=obs_dict, goal=goal_dict)
+ outputs = super(TransformerActorNetwork, self).forward(**forward_kwargs)
+
+ # apply tanh squashing to ensure actions are in [-1, 1]
+ outputs["action"] = torch.tanh(outputs["action"])
+
+ return outputs["action"] # only action sequences
+
+ def _to_string(self):
+ """Info to pretty print."""
+ return "action_dim={}".format(self.ac_dim)
+
+
+class TransformerGMMActorNetwork(TransformerActorNetwork):
+ """
+ A Transformer GMM policy network that predicts sequences of action distributions from observation
+ sequences (assumed to be frame stacked from previous observations).
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ transformer_embed_dim,
+ transformer_num_layers,
+ transformer_num_heads,
+ transformer_context_length,
+ transformer_emb_dropout=0.1,
+ transformer_attn_dropout=0.1,
+ transformer_block_output_dropout=0.1,
+ transformer_sinusoidal_embedding=False,
+ transformer_activation="gelu",
+ transformer_nn_parameter_for_timesteps=False,
+ num_modes=5,
+ min_std=0.01,
+ std_activation="softplus",
+ low_noise_eval=True,
+ use_tanh=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ transformer_embed_dim (int): dimension for embeddings used by transformer
+
+ transformer_num_layers (int): number of transformer blocks to stack
+
+ transformer_num_heads (int): number of attention heads for each
+ transformer block - must divide @transformer_embed_dim evenly. Self-attention is
+ computed over this many partitions of the embedding dimension separately.
+
+ transformer_context_length (int): expected length of input sequences
+
+ transformer_embedding_dropout (float): dropout probability for embedding inputs in transformer
+
+ transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block
+
+ transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block
+
+ num_modes (int): number of GMM modes
+
+ min_std (float): minimum std output from network
+
+ std_activation (None or str): type of activation to use for std deviation. Options are:
+
+ `'softplus'`: Softplus activation applied
+
+ `'exp'`: Exp applied; this corresponds to network output being interpreted as log_std instead of std
+
+ low_noise_eval (float): if True, model will sample from GMM with low std, so that
+ one of the GMM modes will be sampled (approximately)
+
+ use_tanh (bool): if True, use a tanh-Gaussian distribution
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+
+ # parameters specific to GMM actor
+ self.num_modes = num_modes
+ self.min_std = min_std
+ self.low_noise_eval = low_noise_eval
+ self.use_tanh = use_tanh
+
+ # Define activations to use
+ self.activations = {
+ "softplus": F.softplus,
+ "exp": torch.exp,
+ }
+ assert std_activation in self.activations, \
+ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation)
+ self.std_activation = std_activation
+
+ super(TransformerGMMActorNetwork, self).__init__(
+ obs_shapes=obs_shapes,
+ ac_dim=ac_dim,
+ transformer_embed_dim=transformer_embed_dim,
+ transformer_num_layers=transformer_num_layers,
+ transformer_num_heads=transformer_num_heads,
+ transformer_context_length=transformer_context_length,
+ transformer_emb_dropout=transformer_emb_dropout,
+ transformer_attn_dropout=transformer_attn_dropout,
+ transformer_block_output_dropout=transformer_block_output_dropout,
+ transformer_sinusoidal_embedding=transformer_sinusoidal_embedding,
+ transformer_activation=transformer_activation,
+ transformer_nn_parameter_for_timesteps=transformer_nn_parameter_for_timesteps,
+ encoder_kwargs=encoder_kwargs,
+ goal_shapes=goal_shapes,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Tells @MIMO_Transformer superclass about the output dictionary that should be generated
+ at the last layer. Network outputs parameters of GMM distribution.
+ """
+ return OrderedDict(
+ mean=(self.num_modes, self.ac_dim),
+ scale=(self.num_modes, self.ac_dim),
+ logits=(self.num_modes,),
+ )
+
+ def forward_train(self, obs_dict, actions=None, goal_dict=None, low_noise_eval=None):
+ """
+ Return full GMM distribution, which is useful for computing
+ quantities necessary at train-time, like log-likelihood, KL
+ divergence, etc.
+ Args:
+ obs_dict (dict): batch of observations
+ actions (torch.Tensor): batch of actions
+ goal_dict (dict): if not None, batch of goal observations
+ Returns:
+ dists (Distribution): sequence of GMM distributions over the timesteps
+ """
+ if self._is_goal_conditioned:
+ assert goal_dict is not None
+ # repeat the goal observation in time to match dimension with obs_dict
+ mod = list(obs_dict.keys())[0]
+ goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1)
+
+ forward_kwargs = dict(obs=obs_dict, goal=goal_dict)
+
+ outputs = MIMO_Transformer.forward(self, **forward_kwargs)
+
+ means = outputs["mean"]
+ scales = outputs["scale"]
+ logits = outputs["logits"]
+
+ # apply tanh squashing to mean if not using tanh-GMM to ensure means are in [-1, 1]
+ if not self.use_tanh:
+ means = torch.tanh(means)
+
+ if low_noise_eval is None:
+ low_noise_eval = self.low_noise_eval
+ if low_noise_eval and (not self.training):
+ # low-noise for all Gaussian dists
+ scales = torch.ones_like(means) * 1e-4
+ else:
+ # post-process the scale accordingly
+ scales = self.activations[self.std_activation](scales) + self.min_std
+
+ # mixture components - make sure that `batch_shape` for the distribution is equal
+ # to (batch_size, timesteps, num_modes) since MixtureSameFamily expects this shape
+ component_distribution = D.Normal(loc=means, scale=scales)
+ component_distribution = D.Independent(component_distribution, 1) # shift action dim to event shape
+
+ # unnormalized logits to categorical distribution for mixing the modes
+ mixture_distribution = D.Categorical(logits=logits)
+
+ dists = D.MixtureSameFamily(
+ mixture_distribution=mixture_distribution,
+ component_distribution=component_distribution,
+ )
+
+ if self.use_tanh:
+ # Wrap distribution with Tanh
+ dists = TanhWrappedDistribution(base_dist=dists, scale=1.)
+
+ return dists
+
+ def forward(self, obs_dict, actions=None, goal_dict=None):
+ """
+ Samples actions from the policy distribution.
+ Args:
+ obs_dict (dict): batch of observations
+ actions (torch.Tensor): batch of actions
+ goal_dict (dict): if not None, batch of goal observations
+ Returns:
+ action (torch.Tensor): batch of actions from policy distribution
+ """
+ out = self.forward_train(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)
+ return out.sample()
+
+ def _to_string(self):
+ """Info to pretty print."""
+ msg = "action_dim={}, std_activation={}, low_noise_eval={}, num_nodes={}, min_std={}".format(
+ self.ac_dim, self.std_activation, self.low_noise_eval, self.num_modes, self.min_std)
+ return msg
+
+
+class VAEActor(Module):
+ """
+ A VAE that models a distribution of actions conditioned on observations.
+ The VAE prior and decoder are used at test-time as the policy.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ encoder_layer_dims,
+ decoder_layer_dims,
+ latent_dim,
+ device,
+ decoder_is_conditioned=True,
+ decoder_reconstruction_sum_across_elements=False,
+ latent_clip=None,
+ prior_learn=False,
+ prior_is_conditioned=False,
+ prior_layer_dims=(),
+ prior_use_gmm=False,
+ prior_gmm_num_modes=10,
+ prior_gmm_learn_weights=False,
+ prior_use_categorical=False,
+ prior_categorical_dim=10,
+ prior_categorical_gumbel_softmax_hard=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ super(VAEActor, self).__init__()
+
+ self.obs_shapes = obs_shapes
+ self.ac_dim = ac_dim
+ action_shapes = OrderedDict(action=(self.ac_dim,))
+
+ # ensure VAE decoder will squash actions into [-1, 1]
+ output_squash = ['action']
+ output_scales = OrderedDict(action=1.)
+
+ self._vae = VAE(
+ input_shapes=action_shapes,
+ output_shapes=action_shapes,
+ encoder_layer_dims=encoder_layer_dims,
+ decoder_layer_dims=decoder_layer_dims,
+ latent_dim=latent_dim,
+ device=device,
+ condition_shapes=self.obs_shapes,
+ decoder_is_conditioned=decoder_is_conditioned,
+ decoder_reconstruction_sum_across_elements=decoder_reconstruction_sum_across_elements,
+ latent_clip=latent_clip,
+ output_squash=output_squash,
+ output_scales=output_scales,
+ prior_learn=prior_learn,
+ prior_is_conditioned=prior_is_conditioned,
+ prior_layer_dims=prior_layer_dims,
+ prior_use_gmm=prior_use_gmm,
+ prior_gmm_num_modes=prior_gmm_num_modes,
+ prior_gmm_learn_weights=prior_gmm_learn_weights,
+ prior_use_categorical=prior_use_categorical,
+ prior_categorical_dim=prior_categorical_dim,
+ prior_categorical_gumbel_softmax_hard=prior_categorical_gumbel_softmax_hard,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def encode(self, actions, obs_dict, goal_dict=None):
+ """
+ Args:
+ actions (torch.Tensor): a batch of actions
+
+ obs_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to the observation modalities
+ used for conditioning in either the decoder or the prior (or both).
+
+ goal_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities.
+
+ Returns:
+ posterior params (dict): dictionary with the following keys:
+
+ mean (torch.Tensor): posterior encoder means
+
+ logvar (torch.Tensor): posterior encoder logvars
+ """
+ inputs = OrderedDict(action=actions)
+ return self._vae.encode(inputs=inputs, conditions=obs_dict, goals=goal_dict)
+
+ def decode(self, obs_dict=None, goal_dict=None, z=None, n=None):
+ """
+ Thin wrapper around @VaeNets.VAE implementation.
+
+ Args:
+ obs_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. Only needs to be provided if @decoder_is_conditioned
+ or @z is None (since the prior will require it to generate z).
+
+ goal_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities.
+
+ z (torch.Tensor): if provided, these latents are used to generate
+ reconstructions from the VAE, and the prior is not sampled.
+
+ n (int): this argument is used to specify the number of samples to
+ generate from the prior. Only required if @z is None - i.e.
+ sampling takes place
+
+ Returns:
+ recons (dict): dictionary of reconstructed inputs (this will be a dictionary
+ with a single "action" key)
+ """
+ return self._vae.decode(conditions=obs_dict, goals=goal_dict, z=z, n=n)
+
+ def sample_prior(self, obs_dict=None, goal_dict=None, n=None):
+ """
+ Thin wrapper around @VaeNets.VAE implementation.
+
+ Args:
+ n (int): this argument is used to specify the number
+ of samples to generate from the prior.
+
+ obs_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. Only needs to be provided if @prior_is_conditioned.
+
+ goal_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities.
+
+ Returns:
+ z (torch.Tensor): latents sampled from the prior
+ """
+ return self._vae.sample_prior(n=n, conditions=obs_dict, goals=goal_dict)
+
+ def set_gumbel_temperature(self, temperature):
+ """
+ Used by external algorithms to schedule Gumbel-Softmax temperature,
+ which is used during reparametrization at train-time. Should only be
+ used if @prior_use_categorical is True.
+ """
+ self._vae.set_gumbel_temperature(temperature)
+
+ def get_gumbel_temperature(self):
+ """
+ Return current Gumbel-Softmax temperature. Should only be used if
+ @prior_use_categorical is True.
+ """
+ return self._vae.get_gumbel_temperature()
+
+ def output_shape(self, input_shape=None):
+ """
+ This implementation is required by the Module superclass, but is unused since we
+ never chain this module to other ones.
+ """
+ return [self.ac_dim]
+
+ def forward_train(self, actions, obs_dict, goal_dict=None, freeze_encoder=False):
+ """
+ A full pass through the VAE network used during training to construct KL
+ and reconstruction losses. See @VAE class for more info.
+
+ Args:
+ actions (torch.Tensor): a batch of actions
+
+ obs_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to the observation modalities
+ used for conditioning in either the decoder or the prior (or both).
+
+ goal_dict (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities.
+
+ Returns:
+ vae_outputs (dict): a dictionary that contains the following outputs.
+
+ encoder_params (dict): parameters for the posterior distribution
+ from the encoder forward pass
+
+ encoder_z (torch.Tensor): latents sampled from the encoder posterior
+
+ decoder_outputs (dict): action reconstructions from the decoder
+
+ kl_loss (torch.Tensor): KL loss over the batch of data
+
+ reconstruction_loss (torch.Tensor): reconstruction loss over the batch of data
+ """
+ action_inputs = OrderedDict(action=actions)
+ return self._vae.forward(
+ inputs=action_inputs,
+ outputs=action_inputs,
+ conditions=obs_dict,
+ goals=goal_dict,
+ freeze_encoder=freeze_encoder)
+
+ def forward(self, obs_dict, goal_dict=None, z=None):
+ """
+ Samples actions from the policy distribution.
+
+ Args:
+ obs_dict (dict): batch of observations
+ goal_dict (dict): if not None, batch of goal observations
+ z (torch.Tensor): if not None, use the provided batch of latents instead
+ of sampling from the prior
+
+ Returns:
+ action (torch.Tensor): batch of actions from policy distribution
+ """
+ n = None
+ if z is None:
+ # prior will be sampled - so we must provide number of samples explicitly
+ mod = list(obs_dict.keys())[0]
+ n = obs_dict[mod].shape[0]
+ return self.decode(obs_dict=obs_dict, goal_dict=goal_dict, z=z, n=n)["action"]
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/transformers.py b/phantom/submodules/phantom-robomimic/robomimic/models/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..309bff301d02ad561a34021dbea5d370249cef0f
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/transformers.py
@@ -0,0 +1,426 @@
+"""
+Implementation of transformers, mostly based on Andrej's minGPT model.
+See https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
+for more details.
+"""
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from robomimic.models.base_nets import Module
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+
+class GEGLU(nn.Module):
+ """
+ References:
+ Shazeer et al., "GLU Variants Improve Transformer," 2020.
+ https://arxiv.org/abs/2002.05202
+ Implementation: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py
+ """
+
+ def geglu(self, x):
+ assert x.shape[-1] % 2 == 0
+ a, b = x.chunk(2, dim=-1)
+ return a * F.gelu(b)
+
+ def forward(self, x):
+ return self.geglu(x)
+
+
+class PositionalEncoding(nn.Module):
+ """
+ Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
+ """
+
+ def __init__(self, embed_dim):
+ """
+ Standard sinusoidal positional encoding scheme in transformers.
+
+ Positional encoding of the k'th position in the sequence is given by:
+ p(k, 2i) = sin(k/n^(i/d))
+ p(k, 2i+1) = sin(k/n^(i/d))
+
+ n: set to 10K in original Transformer paper
+ d: the embedding dimension
+ i: positions along the projected embedding space (ranges from 0 to d/2)
+
+ Args:
+ embed_dim: The number of dimensions to project the timesteps into.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+
+ def forward(self, x):
+ """
+ Input timestep of shape BxT
+ """
+ position = x
+
+ # computing 1/n^(i/d) in log space and then exponentiating and fixing the shape
+ div_term = (
+ torch.exp(
+ torch.arange(0, self.embed_dim, 2, device=x.device)
+ * (-math.log(10000.0) / self.embed_dim)
+ )
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .repeat(x.shape[0], x.shape[1], 1)
+ )
+ pe = torch.zeros((x.shape[0], x.shape[1], self.embed_dim), device=x.device)
+ pe[:, :, 0::2] = torch.sin(position.unsqueeze(-1) * div_term)
+ pe[:, :, 1::2] = torch.cos(position.unsqueeze(-1) * div_term)
+ return pe.detach()
+
+
+class CausalSelfAttention(Module):
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ context_length,
+ attn_dropout=0.1,
+ output_dropout=0.1,
+ ):
+ """
+ Multi-head masked self-attention layer + projection (MLP layer).
+
+ For normal self-attention (@num_heads = 1), every single input in the sequence is
+ mapped to a key, query, and value embedding of size @embed_dim. For each input,
+ its query vector is compared (using dot-product) with all other key vectors in the
+ sequence, and softmax normalized to compute an attention over all members of the
+ sequence. This is used to take a linear combination of corresponding value embeddings.
+
+ The @num_heads argument is for multi-head attention, where the self-attention operation above
+ is performed in parallel over equal size partitions of the @embed_dim, allowing for different
+ portions of the embedding dimension to model different kinds of attention. The attention
+ output for each head is concatenated together.
+
+ Finally, we use a causal mask here to ensure that each output only depends on inputs that come
+ before it.
+
+ Args:
+ embed_dim (int): dimension of embeddings to use for keys, queries, and values
+ used in self-attention
+
+ num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
+ computed over this many partitions of the embedding dimension separately.
+
+ context_length (int): expected length of input sequences
+
+ attn_dropout (float): dropout probability for attention outputs
+
+ output_dropout (float): dropout probability for final outputs
+ """
+ super(CausalSelfAttention, self).__init__()
+
+ assert (
+ embed_dim % num_heads == 0
+ ), "num_heads: {} does not divide embed_dim: {} exactly".format(num_heads, embed_dim)
+
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.context_length = context_length
+ self.attn_dropout = attn_dropout
+ self.output_dropout = output_dropout
+ self.nets = nn.ModuleDict()
+
+ # projection layers for key, query, value, across all attention heads
+ self.nets["qkv"] = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
+
+ # dropout layers
+ self.nets["attn_dropout"] = nn.Dropout(self.attn_dropout)
+ self.nets["output_dropout"] = nn.Dropout(self.output_dropout)
+
+ # output layer
+ self.nets["output"] = nn.Linear(self.embed_dim, self.embed_dim)
+
+ # causal mask (ensures attention is only over previous inputs) - just a lower triangular matrix of 1s
+ mask = torch.tril(torch.ones(context_length, context_length)).view(
+ 1, 1, context_length, context_length
+ )
+ self.register_buffer("mask", mask)
+
+ def forward(self, x):
+ """
+ Forward pass through Self-Attention block.
+ Input should be shape (B, T, D) where B is batch size, T is seq length (@self.context_length), and
+ D is input dimension (@self.embed_dim).
+ """
+
+ # enforce shape consistency
+ assert len(x.shape) == 3
+ B, T, D = x.shape
+ assert (
+ T <= self.context_length
+ ), "self-attention module can only handle sequences up to {} in length but got length {}".format(
+ self.context_length, T
+ )
+ assert D == self.embed_dim
+ NH = self.num_heads # number of attention heads
+ DH = D // NH # embed dimension for each attention head
+
+ # compute key, query, and value vectors for each member of sequence, and split across attention heads
+ qkv = self.nets["qkv"](x)
+ q, k, v = torch.chunk(qkv, 3, dim=-1)
+ k = k.view(B, T, NH, DH).transpose(1, 2) # [B, NH, T, DH]
+ q = q.view(B, T, NH, DH).transpose(1, 2) # [B, NH, T, DH]
+ v = v.view(B, T, NH, DH).transpose(1, 2) # [B, NH, T, DH]
+
+ # causal self-attention mechanism
+
+ # batched matrix multiplication between queries and keys to get all pair-wise dot-products.
+ # We broadcast across batch and attention heads and get pair-wise dot-products between all pairs of timesteps
+ # [B, NH, T, DH] x [B, NH, DH, T] -> [B, NH, T, T]
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+
+ # use mask to replace entries in dot products with negative inf to ensure they don't contribute to softmax,
+ # then take softmax over last dimension to end up with attention score for each member of sequence.
+ # Note the use of [:T, :T] - this makes it so we can handle sequences less than @self.context_length in length.
+ att = att.masked_fill(self.mask[..., :T, :T] == 0, float("-inf"))
+ att = F.softmax(
+ att, dim=-1
+ ) # shape [B, NH, T, T], last dimension has score over all T for each sequence member
+
+ # dropout on attention
+ att = self.nets["attn_dropout"](att)
+
+ # take weighted sum of value vectors over whole sequence according to attention, with batched matrix multiplication
+ # [B, NH, T, T] x [B, NH, T, DH] -> [B, NH, T, DH]
+ y = att @ v
+ # reshape [B, NH, T, DH] -> [B, T, NH, DH] -> [B, T, NH * DH] = [B, T, D]
+ y = y.transpose(1, 2).contiguous().view(B, T, D)
+
+ # pass through output layer + dropout
+ y = self.nets["output"](y)
+ y = self.nets["output_dropout"](y)
+ return y
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # this module doesn't modify the size of the input, it goes from (B, T, D) -> (B, T, D)
+ return list(input_shape)
+
+
+class SelfAttentionBlock(Module):
+ """
+ A single Transformer Block, that can be chained together repeatedly.
+ It consists of a @CausalSelfAttention module and a small MLP, along with
+ layer normalization and residual connections on each input.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ context_length,
+ attn_dropout=0.1,
+ output_dropout=0.1,
+ activation=nn.GELU(),
+ ):
+ """
+ Args:
+ embed_dim (int): dimension of embeddings to use for keys, queries, and values
+ used in self-attention
+
+ num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
+ computed over this many partitions of the embedding dimension separately.
+
+ context_length (int): expected length of input sequences
+
+ attn_dropout (float): dropout probability for attention outputs
+
+ output_dropout (float): dropout probability for final outputs
+
+ activation (str): string denoting the activation function to use in each transformer block
+ """
+ super(SelfAttentionBlock, self).__init__()
+
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.context_length = context_length
+ self.attn_dropout = attn_dropout
+ self.output_dropout = output_dropout
+ self.nets = nn.ModuleDict()
+
+ # self-attention block
+ self.nets["attention"] = CausalSelfAttention(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ context_length=context_length,
+ attn_dropout=attn_dropout,
+ output_dropout=output_dropout,
+ )
+
+ if type(activation) == GEGLU:
+ mult = 2
+ else:
+ mult = 1
+
+ # small 2-layer MLP
+ self.nets["mlp"] = nn.Sequential(
+ nn.Linear(embed_dim, 4 * embed_dim * mult),
+ activation,
+ nn.Linear(4 * embed_dim, embed_dim),
+ nn.Dropout(output_dropout)
+ )
+
+ # layer normalization for inputs to self-attention module and MLP
+ self.nets["ln1"] = nn.LayerNorm(embed_dim)
+ self.nets["ln2"] = nn.LayerNorm(embed_dim)
+
+ def forward(self, x):
+ """
+ Forward pass - chain self-attention + MLP blocks, with residual connections and layer norms.
+ """
+ x = x + self.nets["attention"](self.nets["ln1"](x))
+ x = x + self.nets["mlp"](self.nets["ln2"](x))
+ return x
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # this module doesn't modify the size of the input, it goes from (B, T, D) -> (B, T, D)
+ return list(input_shape)
+
+
+class GPT_Backbone(Module):
+ """the GPT model, with a context size of block_size"""
+
+ def __init__(
+ self,
+ embed_dim,
+ context_length,
+ attn_dropout=0.1,
+ block_output_dropout=0.1,
+ num_layers=6,
+ num_heads=8,
+ activation="gelu",
+ ):
+ """
+ Args:
+ embed_dim (int): dimension of embeddings to use for keys, queries, and values
+ used in self-attention
+
+ context_length (int): expected length of input sequences
+
+ attn_dropout (float): dropout probability for attention outputs for each transformer block
+
+ block_output_dropout (float): dropout probability for final outputs for each transformer block
+
+ num_layers (int): number of transformer blocks to stack
+
+ num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
+ computed over this many partitions of the embedding dimension separately.
+
+ activation (str): string denoting the activation function to use in each transformer block
+
+ """
+ super(GPT_Backbone, self).__init__()
+
+ self.embed_dim = embed_dim
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.context_length = context_length
+ self.attn_dropout = attn_dropout
+ self.block_output_dropout = block_output_dropout
+
+ if activation == "gelu":
+ self.activation = nn.GELU()
+ elif activation == "geglu":
+ self.activation = GEGLU()
+
+ # create networks
+ self._create_networks()
+
+ # initialize weights
+ self.apply(self._init_weights)
+
+ print(
+ "Created {} model with number of parameters: {}".format(
+ self.__class__.__name__, sum(p.numel() for p in self.parameters())
+ )
+ )
+
+ def _create_networks(self):
+ """
+ Helper function to create networks.
+ """
+ self.nets = nn.ModuleDict()
+
+ # transformer - cascaded transformer blocks
+ self.nets["transformer"] = nn.Sequential(
+ *[
+ SelfAttentionBlock(
+ embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ context_length=self.context_length,
+ attn_dropout=self.attn_dropout,
+ output_dropout=self.block_output_dropout,
+ activation=self.activation,
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+
+ # decoder head
+ self.nets["output_ln"] = nn.LayerNorm(self.embed_dim)
+
+ def _init_weights(self, module):
+ """
+ Weight initializer.
+ """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # this module takes inputs (B, T, @self.input_dim) and produces outputs (B, T, @self.output_dim)
+ return input_shape[:-1] + [self.output_dim]
+
+ def forward(self, inputs):
+ assert inputs.shape[1:] == (self.context_length, self.embed_dim), inputs.shape
+ x = self.nets["transformer"](inputs)
+ transformer_output = self.nets["output_ln"](x)
+ return transformer_output
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/vae_nets.py b/phantom/submodules/phantom-robomimic/robomimic/models/vae_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b4e7f02352126f17fcc92a0e651080a4e0bee6
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/vae_nets.py
@@ -0,0 +1,1386 @@
+"""
+Contains an implementation of Variational Autoencoder (VAE) and other
+variants, including other priors, and RNN-VAEs.
+"""
+import textwrap
+import numpy as np
+from copy import deepcopy
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as D
+
+import robomimic.utils.loss_utils as LossUtils
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.models.base_nets import Module
+from robomimic.models.obs_nets import MIMO_MLP
+
+
+def vae_args_from_config(vae_config):
+ """
+ Generate a set of VAE args that are read from the VAE-specific part
+ of a config (for example see `config.algo.vae` in BCConfig).
+ """
+ vae_args = dict(
+ encoder_layer_dims=vae_config.encoder_layer_dims,
+ decoder_layer_dims=vae_config.decoder_layer_dims,
+ latent_dim=vae_config.latent_dim,
+ decoder_is_conditioned=vae_config.decoder.is_conditioned,
+ decoder_reconstruction_sum_across_elements=vae_config.decoder.reconstruction_sum_across_elements,
+ latent_clip=vae_config.latent_clip,
+ prior_learn=vae_config.prior.learn,
+ prior_is_conditioned=vae_config.prior.is_conditioned,
+ prior_layer_dims=vae_config.prior_layer_dims,
+ prior_use_gmm=vae_config.prior.use_gmm,
+ prior_gmm_num_modes=vae_config.prior.gmm_num_modes,
+ prior_gmm_learn_weights=vae_config.prior.gmm_learn_weights,
+ prior_use_categorical=vae_config.prior.use_categorical,
+ prior_categorical_dim=vae_config.prior.categorical_dim,
+ prior_categorical_gumbel_softmax_hard=vae_config.prior.categorical_gumbel_softmax_hard,
+ )
+ return vae_args
+
+
+class Prior(Module):
+ """
+ Base class for VAE priors. It's basically the same as a @MIMO_MLP network (it
+ instantiates one) but it supports additional methods such as KL loss computation
+ and sampling, and also may learn prior parameters as observation-independent
+ torch Parameters instead of observation-dependent mappings.
+ """
+ def __init__(
+ self,
+ param_shapes,
+ param_obs_dependent,
+ obs_shapes=None,
+ mlp_layer_dims=(),
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ param_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for parameters that determine the prior
+ distribution.
+
+ param_obs_dependent (OrderedDict): a dictionary with boolean
+ values consistent with @param_shapes which determines whether
+ to learn parameters as part of the (obs-dependent) network or
+ directly as learnable parameters.
+
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layer sizes
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ super(Prior, self).__init__()
+
+ assert isinstance(param_shapes, OrderedDict) and isinstance(param_obs_dependent, OrderedDict)
+ assert set(param_shapes.keys()) == set(param_obs_dependent.keys())
+ self.param_shapes = param_shapes
+ self.param_obs_dependent = param_obs_dependent
+
+ net_kwargs = dict(
+ obs_shapes=obs_shapes,
+ mlp_layer_dims=mlp_layer_dims,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+ self._create_layers(net_kwargs)
+
+ def _create_layers(self, net_kwargs):
+ """
+ Create networks and parameters needed by the prior.
+ """
+ self.prior_params = nn.ParameterDict()
+
+ self._is_obs_dependent = False
+ mlp_output_shapes = OrderedDict()
+ for pp in self.param_shapes:
+ if self.param_obs_dependent[pp]:
+ # prior parameters will be a function of observations using a network
+ mlp_output_shapes[pp] = self.param_shapes[pp]
+ else:
+ # learnable prior parameters independent of observation
+ param_init = torch.randn(*self.param_shapes[pp]) / np.sqrt(np.prod(self.param_shapes[pp]))
+ self.prior_params[pp] = torch.nn.Parameter(param_init)
+
+ # only make networks if we have obs-dependent prior parameters
+ self.prior_module = None
+ if len(mlp_output_shapes) > 0:
+ # create @MIMO_MLP that takes obs and goal dicts and returns prior params
+ self._is_obs_dependent = True
+ obs_shapes = net_kwargs["obs_shapes"]
+ goal_shapes = net_kwargs["goal_shapes"]
+ obs_group_shapes = OrderedDict()
+ assert isinstance(obs_shapes, OrderedDict)
+ obs_group_shapes["obs"] = OrderedDict(obs_shapes)
+ if goal_shapes is not None and len(goal_shapes) > 0:
+ assert isinstance(goal_shapes, OrderedDict)
+ obs_group_shapes["goal"] = OrderedDict(goal_shapes)
+ self.prior_module = MIMO_MLP(
+ input_obs_group_shapes=obs_group_shapes,
+ output_shapes=mlp_output_shapes,
+ layer_dims=net_kwargs["mlp_layer_dims"],
+ encoder_kwargs=net_kwargs["encoder_kwargs"],
+ )
+
+ def sample(self, n, obs_dict=None, goal_dict=None):
+ """
+ Returns a batch of samples from the prior distribution.
+
+ Args:
+ n (int): this argument is used to specify the number
+ of samples to generate from the prior.
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent. Leading dimension should
+ be consistent with @n, the number of samples to generate.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ z (torch.Tensor): batch of sampled latent vectors.
+ """
+ raise NotImplementedError
+
+ def kl_loss(self, posterior_params, z=None, obs_dict=None, goal_dict=None):
+ """
+ Computes sample-based KL divergence loss between the Gaussian distribution
+ given by @mu, @logvar and the prior distribution.
+
+ Args:
+ posterior_params (dict): dictionary with keys "mu" and "logvar" corresponding
+ to torch.Tensor batch of means and log-variances of posterior Gaussian
+ distribution.
+
+ z (torch.Tensor): samples from the Gaussian distribution parametrized by
+ @mu and @logvar. May not be needed depending on the prior.
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ kl_loss (torch.Tensor): KL divergence loss
+ """
+ raise NotImplementedError
+
+ def output_shape(self, input_shape=None):
+ """
+ Returns output shape for this module, which is a dictionary instead
+ of a list since outputs are dictionaries.
+ """
+ if self.prior_module is not None:
+ return self.prior_module.output_shape(input_shape)
+ return { k : list(self.param_shapes[k]) for k in self.param_shapes }
+
+ def forward(self, batch_size, obs_dict=None, goal_dict=None):
+ """
+ Computes prior parameters.
+
+ Args:
+ batch_size (int): batch size - this is needed for parameters that are
+ not obs-dependent, to make sure the leading dimension is correct
+ for downstream sampling and loss computation purposes
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ prior_params (dict): dictionary containing prior parameters
+ """
+ prior_params = dict()
+ if self._is_obs_dependent:
+ # forward through network for obs-dependent params
+ prior_params = self.prior_module.forward(obs=obs_dict, goal=goal_dict)
+
+ # return params that do not depend on obs as well
+ for pp in self.param_shapes:
+ if not self.param_obs_dependent[pp]:
+ # ensure leading dimension will be consistent with other params
+ prior_params[pp] = TensorUtils.expand_at(self.prior_params[pp], size=batch_size, dim=0)
+
+ # ensure leading dimensions are all consistent
+ TensorUtils.assert_size_at_dim(prior_params, size=batch_size, dim=0,
+ msg="prior params dim 0 mismatch in forward")
+
+ return prior_params
+
+
+class GaussianPrior(Prior):
+ """
+ A class that holds functionality for learning both unimodal Gaussian priors and
+ multimodal Gaussian Mixture Model priors for use in VAEs.
+ """
+ def __init__(
+ self,
+ latent_dim,
+ device,
+ latent_clip=None,
+ learnable=False,
+ use_gmm=False,
+ gmm_num_modes=10,
+ gmm_learn_weights=False,
+ obs_shapes=None,
+ mlp_layer_dims=(),
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ latent_dim (int): size of latent dimension for the prior
+
+ device (torch.Device): where the module should live (i.e. cpu, gpu)
+
+ latent_clip (float): if provided, clip all latents sampled at
+ test-time in each dimension to (-@latent_clip, @latent_clip)
+
+ learnable (bool): if True, learn the parameters of the prior (as opposed
+ to a default N(0, 1) prior)
+
+ use_gmm (bool): if True, learn a Gaussian Mixture Model (GMM)
+ prior instead of a unimodal Gaussian prior. To use this option,
+ @learnable must be set to True.
+
+ gmm_num_modes (int): number of GMM modes to learn. Only
+ used if @use_gmm is True.
+
+ gmm_learn_weights (bool): if True, learn the weights of the GMM
+ model instead of setting them to be uniform across all the modes.
+ Only used if @use_gmm is True.
+
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations. If provided, assumes that
+ the prior should depend on observation inputs, and networks
+ will be created to output prior parameters.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layer sizes
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ self.device = device
+ self.latent_dim = latent_dim
+ self.latent_clip = latent_clip
+ self.learnable = learnable
+
+ self.use_gmm = use_gmm
+ if self.use_gmm:
+ self.num_modes = gmm_num_modes
+ else:
+ # unimodal Gaussian prior
+ self.num_modes = 1
+ self.gmm_learn_weights = gmm_learn_weights
+
+ self._input_dependent = (obs_shapes is not None) and (len(obs_shapes) > 0)
+
+ if self._input_dependent:
+ assert learnable
+ assert isinstance(obs_shapes, OrderedDict)
+
+ # network will generate mean and logvar
+ param_shapes = OrderedDict(
+ mean=(self.num_modes, self.latent_dim,),
+ logvar=(self.num_modes, self.latent_dim,),
+ )
+ param_obs_dependent = OrderedDict(mean=True, logvar=True)
+
+ if self.use_gmm and self.gmm_learn_weights:
+ # network generates GMM weights
+ param_shapes["weight"] = (self.num_modes,)
+ param_obs_dependent["weight"] = True
+ else:
+ # learn obs-indep mean / logvar
+ param_shapes = OrderedDict(
+ mean=(1, self.num_modes, self.latent_dim),
+ logvar=(1, self.num_modes, self.latent_dim),
+ )
+ param_obs_dependent = OrderedDict(mean=False, logvar=False)
+
+ if self.use_gmm and self.gmm_learn_weights:
+ # learn obs-indep GMM weights
+ param_shapes["weight"] = (1, self.num_modes)
+ param_obs_dependent["weight"] = False
+
+ super(GaussianPrior, self).__init__(
+ param_shapes=param_shapes,
+ param_obs_dependent=param_obs_dependent,
+ obs_shapes=obs_shapes,
+ mlp_layer_dims=mlp_layer_dims,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _create_layers(self, net_kwargs):
+ """
+ Update from superclass to only create parameters / networks if not using
+ N(0, 1) Gaussian prior.
+ """
+ if self.learnable:
+ super(GaussianPrior, self)._create_layers(net_kwargs)
+
+ def sample(self, n, obs_dict=None, goal_dict=None):
+ """
+ Returns a batch of samples from the prior distribution.
+
+ Args:
+ n (int): this argument is used to specify the number
+ of samples to generate from the prior.
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent. Leading dimension should
+ be consistent with @n, the number of samples to generate.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ z (torch.Tensor): batch of sampled latent vectors.
+ """
+
+ # check consistency between n and obs_dict
+ if self._input_dependent:
+ TensorUtils.assert_size_at_dim(obs_dict, size=n, dim=0,
+ msg="obs dict and n mismatch in @sample")
+
+ if self.learnable:
+
+ # forward to get parameters
+ out = self.forward(batch_size=n, obs_dict=obs_dict, goal_dict=goal_dict)
+ prior_means, prior_logvars, prior_logweights = out["means"], out["logvars"], out["logweights"]
+
+ if prior_logweights is not None:
+ prior_weights = torch.exp(prior_logweights)
+
+ if self.use_gmm:
+ # learned GMM
+
+ # make uniform weights (in the case that weights were not learned)
+ if not self.gmm_learn_weights:
+ prior_weights = torch.ones(n, self.num_modes).to(prior_means.device) / self.num_modes
+
+ # sample modes
+ gmm_mode_indices = D.Categorical(prior_weights).sample()
+
+ # get GMM centers and sample using reparametrization trick
+ selected_means = TensorUtils.gather_sequence(prior_means, indices=gmm_mode_indices)
+ selected_logvars = TensorUtils.gather_sequence(prior_logvars, indices=gmm_mode_indices)
+ z = TorchUtils.reparameterize(selected_means, selected_logvars)
+
+ else:
+ # learned unimodal Gaussian - remove mode dim and sample from Gaussian using reparametrization trick
+ z = TorchUtils.reparameterize(prior_means[:, 0, :], prior_logvars[:, 0, :])
+
+ else:
+ # sample from N(0, 1)
+ z = torch.randn(n, self.latent_dim).float().to(self.device)
+
+ if self.latent_clip is not None:
+ z = z.clamp(-self.latent_clip, self.latent_clip)
+
+ return z
+
+ def kl_loss(self, posterior_params, z=None, obs_dict=None, goal_dict=None):
+ """
+ Computes sample-based KL divergence loss between the Gaussian distribution
+ given by @mu, @logvar and the prior distribution.
+
+ Args:
+ posterior_params (dict): dictionary with keys "mu" and "logvar" corresponding
+ to torch.Tensor batch of means and log-variances of posterior Gaussian
+ distribution.
+
+ z (torch.Tensor): samples from the Gaussian distribution parametrized by
+ @mu and @logvar. Only needed if @self.use_gmm is True.
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ kl_loss (torch.Tensor): KL divergence loss
+ """
+ mu = posterior_params["mean"]
+ logvar = posterior_params["logvar"]
+
+ if not self.learnable:
+ # closed-form Gaussian KL from N(0, 1) prior
+ return LossUtils.KLD_0_1_loss(mu=mu, logvar=logvar)
+
+ # forward to get parameters
+ out = self.forward(batch_size=mu.shape[0], obs_dict=obs_dict, goal_dict=goal_dict)
+ prior_means, prior_logvars, prior_logweights = out["means"], out["logvars"], out["logweights"]
+
+ if not self.use_gmm:
+ # collapse mode dimension and compute Gaussian KL in closed-form
+ prior_means = prior_means[:, 0, :]
+ prior_logvars = prior_logvars[:, 0, :]
+ return LossUtils.KLD_gaussian_loss(
+ mu_1=mu,
+ logvar_1=logvar,
+ mu_2=prior_means,
+ logvar_2=prior_logvars,
+ )
+
+ # GMM KL loss computation
+ var = torch.exp(logvar.clamp(-8, 30)) # clamp for numerical stability
+ prior_vars = torch.exp(prior_logvars.clamp(-8, 30))
+ kl_loss = LossUtils.log_normal(x=z, m=mu, v=var) \
+ - LossUtils.log_normal_mixture(x=z, m=prior_means, v=prior_vars, log_w=prior_logweights)
+ return kl_loss.mean()
+
+ def forward(self, batch_size, obs_dict=None, goal_dict=None):
+ """
+ Computes means, logvars, and GMM weights (if using GMM and learning weights).
+
+ Args:
+ batch_size (int): batch size - this is needed for parameters that are
+ not obs-dependent, to make sure the leading dimension is correct
+ for downstream sampling and loss computation purposes
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ prior_params (dict): dictionary containing prior parameters
+ """
+ assert self.learnable
+ prior_params = super(GaussianPrior, self).forward(
+ batch_size=batch_size, obs_dict=obs_dict, goal_dict=goal_dict)
+
+ if self.use_gmm and self.gmm_learn_weights:
+ # normalize learned weight outputs to sum to 1
+ logweights = F.log_softmax(prior_params["weight"], dim=-1)
+ else:
+ logweights = None
+ assert "weight" not in prior_params
+
+ out = dict(means=prior_params["mean"], logvars=prior_params["logvar"], logweights=logweights)
+ return out
+
+ def __repr__(self):
+ """Pretty print network"""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 4
+ msg += textwrap.indent("latent_dim={}\n".format(self.latent_dim), indent)
+ msg += textwrap.indent("latent_clip={}\n".format(self.latent_clip), indent)
+ msg += textwrap.indent("learnable={}\n".format(self.learnable), indent)
+ msg += textwrap.indent("input_dependent={}\n".format(self._input_dependent), indent)
+ msg += textwrap.indent("use_gmm={}\n".format(self.use_gmm), indent)
+ if self.use_gmm:
+ msg += textwrap.indent("gmm_num_nodes={}\n".format(self.num_modes), indent)
+ msg += textwrap.indent("gmm_learn_weights={}\n".format(self.gmm_learn_weights), indent)
+ if self.learnable:
+ if self.prior_module is not None:
+ msg += textwrap.indent("\nprior_module={}\n".format(self.prior_module), indent)
+ msg += textwrap.indent("prior_params={}\n".format(self.prior_params), indent)
+ msg = header + '(\n' + msg + ')'
+ return msg
+
+
+class CategoricalPrior(Prior):
+ """
+ A class that holds functionality for learning categorical priors for use
+ in VAEs.
+ """
+ def __init__(
+ self,
+ latent_dim,
+ categorical_dim,
+ device,
+ learnable=False,
+ obs_shapes=None,
+ mlp_layer_dims=(),
+ goal_shapes=None,
+ encoder_kwargs=None,
+
+ ):
+ """
+ Args:
+ latent_dim (int): size of latent dimension for the prior
+
+ categorical_dim (int): size of categorical dimension (number of classes
+ for each dimension of latent space)
+
+ device (torch.Device): where the module should live (i.e. cpu, gpu)
+
+ learnable (bool): if True, learn the parameters of the prior (as opposed
+ to a default N(0, 1) prior)
+
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations. If provided, assumes that
+ the prior should depend on observation inputs, and networks
+ will be created to output prior parameters.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layer sizes
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ self.device = device
+ self.latent_dim = latent_dim
+ self.categorical_dim = categorical_dim
+ self.learnable = learnable
+
+ self._input_dependent = (obs_shapes is not None) and (len(obs_shapes) > 0)
+
+ if self._input_dependent:
+ assert learnable
+ assert isinstance(obs_shapes, OrderedDict)
+
+ # network will generate logits for categorical distributions
+ param_shapes = OrderedDict(
+ logit=(self.latent_dim, self.categorical_dim,)
+ )
+ param_obs_dependent = OrderedDict(logit=True)
+ else:
+ # learn obs-indep mean / logvar
+ param_shapes = OrderedDict(
+ logit=(1, self.latent_dim, self.categorical_dim),
+ )
+ param_obs_dependent = OrderedDict(logit=False)
+
+ super(CategoricalPrior, self).__init__(
+ param_shapes=param_shapes,
+ param_obs_dependent=param_obs_dependent,
+ obs_shapes=obs_shapes,
+ mlp_layer_dims=mlp_layer_dims,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _create_layers(self, net_kwargs):
+ """
+ Update from superclass to only create parameters / networks if not using
+ uniform categorical prior.
+ """
+ if self.learnable:
+ super(CategoricalPrior, self)._create_layers(net_kwargs)
+
+ def sample(self, n, obs_dict=None, goal_dict=None):
+ """
+ Returns a batch of samples from the prior distribution.
+
+ Args:
+ n (int): this argument is used to specify the number
+ of samples to generate from the prior.
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent. Leading dimension should
+ be consistent with @n, the number of samples to generate.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ z (torch.Tensor): batch of sampled latent vectors.
+ """
+
+ # check consistency between n and obs_dict
+ if self._input_dependent:
+ TensorUtils.assert_size_at_dim(obs_dict, size=n, dim=0,
+ msg="obs dict and n mismatch in @sample")
+
+ if self.learnable:
+
+ # forward to get parameters
+ out = self.forward(batch_size=n, obs_dict=obs_dict, goal_dict=goal_dict)
+ prior_logits = out["logit"]
+
+ # sample one-hot latents from categorical distribution
+ dist = D.Categorical(logits=prior_logits)
+ z = TensorUtils.to_one_hot(dist.sample(), num_class=self.categorical_dim)
+
+ else:
+ # try to include a categorical sample for each class if possible (ensuring rough uniformity)
+ if (self.latent_dim == 1) and (self.categorical_dim <= n):
+ # include samples [0, 1, ..., C - 1] and then repeat until batch is filled
+ dist_samples = torch.arange(n).remainder(self.categorical_dim).unsqueeze(-1).to(self.device)
+ else:
+ # sample one-hot latents from uniform categorical distribution for each latent dimension
+ probs = torch.ones(n, self.latent_dim, self.categorical_dim).float().to(self.device)
+ dist_samples = D.Categorical(probs=probs).sample()
+ z = TensorUtils.to_one_hot(dist_samples, num_class=self.categorical_dim)
+
+ # reshape [B, D, C] to [B, D * C] to be consistent with other priors that return flat latents
+ z = z.reshape(*z.shape[:-2], -1)
+ return z
+
+ def kl_loss(self, posterior_params, z=None, obs_dict=None, goal_dict=None):
+ """
+ Computes KL divergence loss between the Categorical distribution
+ given by the unnormalized logits @logits and the prior distribution.
+
+ Args:
+ posterior_params (dict): dictionary with key "logits" corresponding
+ to torch.Tensor batch of unnormalized logits of shape [B, D * C]
+ that corresponds to the posterior categorical distribution
+
+ z (torch.Tensor): samples from encoder - unused for this prior
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ kl_loss (torch.Tensor): KL divergence loss
+ """
+ logits = posterior_params["logit"].reshape(-1, self.latent_dim, self.categorical_dim)
+ if not self.learnable:
+ # prior logits correspond to uniform categorical distribution
+ prior_logits = torch.zeros_like(logits)
+ else:
+ # forward to get parameters
+ out = self.forward(batch_size=posterior_params["logit"].shape[0], obs_dict=obs_dict, goal_dict=goal_dict)
+ prior_logits = out["logit"]
+
+ prior_dist = D.Categorical(logits=prior_logits)
+ posterior_dist = D.Categorical(logits=logits)
+
+ # sum over latent dimensions, but average over batch dimension
+ kl_loss = D.kl_divergence(posterior_dist, prior_dist)
+ assert len(kl_loss.shape) == 2
+ return kl_loss.sum(-1).mean()
+
+ def forward(self, batch_size, obs_dict=None, goal_dict=None):
+ """
+ Computes prior logits (unnormalized log-probs).
+
+ Args:
+ batch_size (int): batch size - this is needed for parameters that are
+ not obs-dependent, to make sure the leading dimension is correct
+ for downstream sampling and loss computation purposes
+
+ obs_dict (dict): inputs according to @obs_shapes. Only needs to be provided
+ if any prior parameters are obs-dependent.
+
+ goal_dict (dict): inputs according to @goal_shapes (only if using goal observations)
+
+ Returns:
+ prior_params (dict): dictionary containing prior parameters
+ """
+ assert self.learnable
+ return super(CategoricalPrior, self).forward(
+ batch_size=batch_size, obs_dict=obs_dict, goal_dict=goal_dict)
+
+ def __repr__(self):
+ """Pretty print network"""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = ''
+ indent = ' ' * 4
+ msg += textwrap.indent("latent_dim={}\n".format(self.latent_dim), indent)
+ msg += textwrap.indent("categorical_dim={}\n".format(self.categorical_dim), indent)
+ msg += textwrap.indent("learnable={}\n".format(self.learnable), indent)
+ msg += textwrap.indent("input_dependent={}\n".format(self._input_dependent), indent)
+ if self.learnable:
+ if self.prior_module is not None:
+ msg += textwrap.indent("\nprior_module={}\n".format(self.prior_module), indent)
+ msg += textwrap.indent("prior_params={}\n".format(self.prior_params), indent)
+ msg = header + '(\n' + msg + ')'
+ return msg
+
+
+class VAE(torch.nn.Module):
+ """
+ A Variational Autoencoder (VAE), as described in https://arxiv.org/abs/1312.6114.
+
+ Models a distribution p(X) or a conditional distribution p(X | Y), where each
+ variable can consist of multiple modalities. The target variable X, whose
+ distribution is modeled, is specified through the @input_shapes argument,
+ which is a map between modalities (strings) and expected shapes. In this way,
+ a variable that consists of multiple kinds of data (e.g. image and flat-dimensional)
+ can be modeled as well. A separate @output_shapes argument is used to specify the
+ expected reconstructions - this allows for asymmetric reconstruction (for example,
+ reconstructing low-resolution images).
+
+ This implementation supports learning conditional distributions as well (cVAE).
+ The conditioning variable Y is specified through the @condition_shapes argument,
+ which is also a map between modalities (strings) and expected shapes. In this way,
+ variables with multiple kinds of data (e.g. image and flat-dimensional) can
+ jointly be conditioned on. By default, the decoder takes the conditioning
+ variable Y as input. To force the decoder to reconstruct from just the latent,
+ set @decoder_is_conditioned to False (in this case, the prior must be conditioned).
+
+ The implementation also supports learning expressive priors instead of using
+ the usual N(0, 1) prior. There are three kinds of priors supported - Gaussian,
+ Gaussian Mixture Model (GMM), and Categorical. For each prior, the parameters can
+ be learned as independent parameters, or be learned as functions of the conditioning
+ variable Y (by setting @prior_is_conditioned).
+ """
+ def __init__(
+ self,
+ input_shapes,
+ output_shapes,
+ encoder_layer_dims,
+ decoder_layer_dims,
+ latent_dim,
+ device,
+ condition_shapes=None,
+ decoder_is_conditioned=True,
+ decoder_reconstruction_sum_across_elements=False,
+ latent_clip=None,
+ output_squash=(),
+ output_scales=None,
+ output_ranges=None,
+ prior_learn=False,
+ prior_is_conditioned=False,
+ prior_layer_dims=(),
+ prior_use_gmm=False,
+ prior_gmm_num_modes=10,
+ prior_gmm_learn_weights=False,
+ prior_use_categorical=False,
+ prior_categorical_dim=10,
+ prior_categorical_gumbel_softmax_hard=False,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ input_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for all encoder-specific inputs. This corresponds
+ to the variable X whose distribution we are learning.
+
+ output_shapes (OrderedDict): a dictionary that maps modality to
+ expected shape for outputs to reconstruct. Usually, this is
+ the same as @input_shapes but this argument allows
+ for asymmetries, such as reconstructing low-resolution
+ images.
+
+ encoder_layer_dims ([int]): sequence of integers for the encoder hidden
+ layer sizes.
+
+ decoder_layer_dims ([int]): sequence of integers for the decoder hidden
+ layer sizes.
+
+ latent_dim (int): dimension of latent space for the VAE
+
+ device (torch.Device): where the module should live (i.e. cpu, gpu)
+
+ condition_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for all conditioning inputs. If this is provided,
+ a conditional distribution is modeled (cVAE). Conditioning takes
+ place in the decoder by default, and optionally, the prior.
+
+ decoder_is_conditioned (bool): whether to condition the decoder
+ on the conditioning variables. True by default. Only used if
+ @condition_shapes is not empty.
+
+ decoder_reconstruction_sum_across_elements (bool): by default, VAEs
+ average across modality elements and modalities when computing
+ reconstruction loss. If this is True, sum across all dimensions
+ and modalities instead.
+
+ latent_clip (float): if provided, clip all latents sampled at
+ test-time in each dimension to (-@latent_clip, @latent_clip)
+
+ output_squash ([str]): an iterable of modalities that should be
+ a subset of @output_shapes. The decoder outputs for these
+ modalities will be squashed into a symmetric range [-a, a]
+ by using a tanh layer and then scaling the output with the
+ corresponding value in the @output_scales dictionary.
+
+ output_scales (dict): a dictionary that maps modality to a
+ scaling value. Used in conjunction with @output_squash.
+
+ output_ranges (dict): a dictionary of [a, b] specifying the output range.
+ when output_ranges is specified (not None), output_scales should be None
+
+ prior_learn (bool): if True, the prior distribution parameters
+ are also learned through the KL-divergence loss (instead
+ of being constrained to a N(0, 1) Gaussian distribution).
+ If @prior_is_conditioned is True, a global set of parameters
+ are learned, otherwise, a prior network that maps between
+ modalities in @condition_shapes and prior parameters is
+ learned. By default, a Gaussian prior is learned, unless
+ @prior_use_gmm is True, in which case a Gaussian Mixture
+ Model (GMM) prior is learned.
+
+ prior_is_conditioned (bool): whether to condition the prior
+ on the conditioning variables. False by default. Only used if
+ @condition_shapes is not empty. If this is set to True,
+ @prior_learn must be True.
+
+ prior_layer_dims ([int]): sequence of integers for the prior hidden layer
+ sizes. Only used for learned priors that take condition variables as
+ input (i.e. when @prior_learn and @prior_is_conditioned are set to True,
+ and @condition_shapes is not empty).
+
+ prior_use_gmm (bool): if True, learn a Gaussian Mixture Model (GMM)
+ prior instead of a unimodal Gaussian prior. To use this option,
+ @prior_learn must be set to True.
+
+ prior_gmm_num_modes (int): number of GMM modes to learn. Only
+ used if @prior_use_gmm is True.
+
+ prior_gmm_learn_weights (bool): if True, learn the weights of the GMM
+ model instead of setting them to be uniform across all the modes.
+ Only used if @prior_use_gmm is True.
+
+ prior_use_categorical (bool): if True, use a categorical prior instead of
+ a unimodal Gaussian prior. This will also cause the encoder to output
+ a categorical distribution, and will use the Gumbel-Softmax trick
+ for reparametrization.
+
+ prior_categorical_dim (int): categorical dimension - each latent sampled
+ from the prior will be of shape (@latent_dim, @prior_categorical_dim)
+ and will be "one-hot" in the latter dimension. Only used if
+ @prior_use_categorical is True.
+
+ prior_categorical_gumbel_softmax_hard (bool): if True, use the "hard" version of
+ Gumbel Softmax for reparametrization. Only used if @prior_use_categorical is True.
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations. Goals are treates as additional
+ conditioning inputs. They are usually specified separately because
+ they have duplicate modalities as the conditioning inputs (otherwise
+ they could just be added to the set of conditioning inputs).
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ super(VAE, self).__init__()
+
+ self.latent_dim = latent_dim
+ self.latent_clip = latent_clip
+ self.device = device
+
+ # encoder and decoder input dicts and output shapes dict for reconstruction
+ assert isinstance(input_shapes, OrderedDict)
+ assert isinstance(output_shapes, OrderedDict)
+ self.input_shapes = deepcopy(input_shapes)
+ self.output_shapes = deepcopy(output_shapes)
+
+ # check for conditioning (cVAE)
+ self._is_cvae = False
+ self.condition_shapes = deepcopy(condition_shapes) if condition_shapes is not None else OrderedDict()
+ if len(self.condition_shapes) > 0:
+ # this is a cVAE - we learn a conditional distribution p(X | Y)
+ assert isinstance(self.condition_shapes, OrderedDict)
+ self._is_cvae = True
+ self.decoder_is_conditioned = decoder_is_conditioned
+ self.prior_is_conditioned = prior_is_conditioned
+ assert self.decoder_is_conditioned or self.prior_is_conditioned, \
+ "cVAE must be conditioned in decoder and/or prior"
+ if self.prior_is_conditioned:
+ assert prior_learn, "to pass conditioning inputs to prior, prior must be learned"
+
+ # check for goal conditioning
+ self._is_goal_conditioned = False
+ self.goal_shapes = deepcopy(goal_shapes) if goal_shapes is not None else OrderedDict()
+ if len(self.goal_shapes) > 0:
+ assert self._is_cvae, "to condition VAE on goals, it must be a cVAE"
+ assert isinstance(self.goal_shapes, OrderedDict)
+ self._is_goal_conditioned = True
+
+ self.encoder_layer_dims = encoder_layer_dims
+ self.decoder_layer_dims = decoder_layer_dims
+
+ # determines whether outputs are squashed with tanh and if so, to what scaling
+ assert not (output_scales is not None and output_ranges is not None)
+ self.output_squash = output_squash
+ self.output_scales = output_scales if output_scales is not None else OrderedDict()
+ self.output_ranges = output_ranges if output_ranges is not None else OrderedDict()
+
+ assert set(self.output_squash) == set(self.output_scales.keys())
+ assert set(self.output_squash).issubset(set(self.output_shapes))
+
+ # decoder settings
+ self.decoder_reconstruction_sum_across_elements = decoder_reconstruction_sum_across_elements
+
+ # prior parameters
+ self.prior_learn = prior_learn
+ self.prior_layer_dims = prior_layer_dims
+ self.prior_use_gmm = prior_use_gmm
+ self.prior_gmm_num_modes = prior_gmm_num_modes
+ self.prior_gmm_learn_weights = prior_gmm_learn_weights
+ self.prior_use_categorical = prior_use_categorical
+ self.prior_categorical_dim = prior_categorical_dim
+ self.prior_categorical_gumbel_softmax_hard = prior_categorical_gumbel_softmax_hard
+ assert np.sum([self.prior_use_gmm, self.prior_use_categorical]) <= 1
+
+ # for obs core
+ self._encoder_kwargs = encoder_kwargs
+
+ if self.prior_use_gmm:
+ assert self.prior_learn, "GMM must be learned"
+
+ if self.prior_use_categorical:
+ # initialize temperature for Gumbel-Softmax
+ self.set_gumbel_temperature(1.0)
+
+ # create encoder, decoder, prior
+ self._create_layers()
+
+ def _create_layers(self):
+ """
+ Creates the encoder, decoder, and prior networks.
+ """
+ self.nets = nn.ModuleDict()
+
+ # VAE Encoder
+ self._create_encoder()
+
+ # VAE Decoder
+ self._create_decoder()
+
+ # VAE Prior.
+ self._create_prior()
+
+ def _create_encoder(self):
+ """
+ Helper function to create encoder.
+ """
+
+ # encoder takes "input" dictionary and possibly "condition" (if cVAE) and "goal" (if goal-conditioned)
+ encoder_obs_group_shapes = OrderedDict()
+ encoder_obs_group_shapes["input"] = OrderedDict(self.input_shapes)
+ if self._is_cvae:
+ encoder_obs_group_shapes["condition"] = OrderedDict(self.condition_shapes)
+ if self._is_goal_conditioned:
+ encoder_obs_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+
+ # encoder outputs posterior distribution parameters
+ if self.prior_use_categorical:
+ encoder_output_shapes = OrderedDict(
+ logit=(self.latent_dim * self.prior_categorical_dim,),
+ )
+ else:
+ encoder_output_shapes = OrderedDict(
+ mean=(self.latent_dim,),
+ logvar=(self.latent_dim,),
+ )
+
+ self.nets["encoder"] = MIMO_MLP(
+ input_obs_group_shapes=encoder_obs_group_shapes,
+ output_shapes=encoder_output_shapes,
+ layer_dims=self.encoder_layer_dims,
+ encoder_kwargs=self._encoder_kwargs,
+ )
+
+ def _create_decoder(self):
+ """
+ Helper function to create decoder.
+ """
+
+ # decoder takes latent (included as "input" observation group) and possibly "condition" (if cVAE) and "goal" (if goal-conditioned)
+ decoder_obs_group_shapes = OrderedDict()
+ latent_shape = (self.latent_dim,)
+ if self.prior_use_categorical:
+ latent_shape = (self.latent_dim * self.prior_categorical_dim,)
+ decoder_obs_group_shapes["input"] = OrderedDict(latent=latent_shape)
+ if self._is_cvae:
+ decoder_obs_group_shapes["condition"] = OrderedDict(self.condition_shapes)
+ if self._is_goal_conditioned:
+ decoder_obs_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+
+ self.nets["decoder"] = MIMO_MLP(
+ input_obs_group_shapes=decoder_obs_group_shapes,
+ output_shapes=self.output_shapes,
+ layer_dims=self.decoder_layer_dims,
+ encoder_kwargs=self._encoder_kwargs,
+ )
+
+ def _create_prior(self):
+ """
+ Helper function to create prior.
+ """
+
+ # prior possibly takes "condition" (if cVAE) and "goal" (if goal-conditioned)
+ prior_obs_group_shapes = OrderedDict(condition=None, goal=None)
+ if self._is_cvae and self.prior_is_conditioned:
+ prior_obs_group_shapes["condition"] = OrderedDict(self.condition_shapes)
+ if self._is_goal_conditioned:
+ prior_obs_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+
+ if self.prior_use_categorical:
+ self.nets["prior"] = CategoricalPrior(
+ latent_dim=self.latent_dim,
+ categorical_dim=self.prior_categorical_dim,
+ device=self.device,
+ learnable=self.prior_learn,
+ obs_shapes=prior_obs_group_shapes["condition"],
+ mlp_layer_dims=self.prior_layer_dims,
+ goal_shapes=prior_obs_group_shapes["goal"],
+ encoder_kwargs=self._encoder_kwargs,
+ )
+ else:
+ self.nets["prior"] = GaussianPrior(
+ latent_dim=self.latent_dim,
+ device=self.device,
+ latent_clip=self.latent_clip,
+ learnable=self.prior_learn,
+ use_gmm=self.prior_use_gmm,
+ gmm_num_modes=self.prior_gmm_num_modes,
+ gmm_learn_weights=self.prior_gmm_learn_weights,
+ obs_shapes=prior_obs_group_shapes["condition"],
+ mlp_layer_dims=self.prior_layer_dims,
+ goal_shapes=prior_obs_group_shapes["goal"],
+ encoder_kwargs=self._encoder_kwargs,
+ )
+
+ def encode(self, inputs, conditions=None, goals=None):
+ """
+ Args:
+ inputs (dict): a dictionary that maps input modalities to torch.Tensor
+ batches. These should correspond to the encoder-only modalities
+ (i.e. @self.encoder_only_shapes).
+
+ conditions (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to the modalities used for conditioning
+ in either the decoder or the prior (or both). Only for cVAEs.
+
+ goals (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities. Only for cVAEs.
+
+ Returns:
+ posterior params (dict): dictionary with posterior parameters
+ """
+ return self.nets["encoder"](
+ input=inputs,
+ condition=conditions,
+ goal=goals,
+ )
+
+ def reparameterize(self, posterior_params):
+ """
+ Args:
+ posterior params (dict): dictionary from encoder forward pass that
+ parametrizes the encoder distribution
+
+ Returns:
+ z (torch.Tensor): sampled latents that are also differentiable
+ """
+ if self.prior_use_categorical:
+ # reshape to [B, D, C] to take softmax across categorical classes
+ logits = posterior_params["logit"].reshape(-1, self.latent_dim, self.prior_categorical_dim)
+ z = F.gumbel_softmax(
+ logits=logits,
+ tau=self._gumbel_temperature,
+ hard=self.prior_categorical_gumbel_softmax_hard,
+ dim=-1,
+ )
+ # reshape to [B, D * C], since downstream networks expect flat latents
+ return TensorUtils.flatten(z)
+
+ return TorchUtils.reparameterize(
+ mu=posterior_params["mean"],
+ logvar=posterior_params["logvar"],
+ )
+
+ def decode(self, conditions=None, goals=None, z=None, n=None):
+ """
+ Pass latents through decoder. Latents should be passed in to
+ this function at train-time for backpropagation, but they
+ can be left out at test-time. In this case, latents will
+ be sampled using the VAE prior.
+
+ Args:
+ conditions (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to the modalities used for conditioning
+ in either the decoder or the prior (or both). Only for cVAEs.
+
+ goals (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities. Only for cVAEs.
+
+ z (torch.Tensor): if provided, these latents are used to generate
+ reconstructions from the VAE, and the prior is not sampled.
+
+ n (int): this argument is used to specify the number of samples to
+ generate from the prior. Only required if @z is None - i.e.
+ sampling takes place
+
+ Returns:
+ recons (dict): dictionary of reconstructed inputs
+ """
+
+ if z is None:
+ # sample latents from prior distribution
+ assert n is not None
+ z = self.sample_prior(n=n, conditions=conditions, goals=goals)
+
+ # decoder takes latents as input, and maybe condition variables
+ # and goal variables
+ inputs = dict(
+ input=dict(latent=z),
+ condition=conditions,
+ goal=goals,
+ )
+
+ # pass through decoder to reconstruct variables in @self.output_shapes
+ recons = self.nets["decoder"](**inputs)
+
+ # apply tanh squashing to output modalities
+ for k in self.output_squash:
+ recons[k] = self.output_scales[k] * torch.tanh(recons[k])
+
+ for k, v_range in self.output_ranges.items():
+ assert v_range[1] > v_range[0]
+ recons[k] = torch.sigmoid(recons[k]) * (v_range[1] - v_range[0]) + v_range[0]
+ return recons
+
+ def sample_prior(self, n, conditions=None, goals=None):
+ """
+ Samples from the prior using the prior parameters.
+
+ Args:
+ n (int): this argument is used to specify the number
+ of samples to generate from the prior.
+
+ conditions (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to the modalities used for conditioning
+ in either the decoder or the prior (or both). Only for cVAEs.
+
+ goals (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities. Only for cVAEs.
+
+ Returns:
+ z (torch.Tensor): sampled latents from the prior
+ """
+ return self.nets["prior"].sample(n=n, obs_dict=conditions, goal_dict=goals)
+
+ def kl_loss(self, posterior_params, encoder_z=None, conditions=None, goals=None):
+ """
+ Computes KL divergence loss given the results of the VAE encoder forward
+ pass and the conditioning and goal modalities (if the prior is input-dependent).
+
+ Args:
+ posterior_params (dict): dictionary with keys "mu" and "logvar" corresponding
+ to torch.Tensor batch of means and log-variances of posterior Gaussian
+ distribution. This is the output of @self.encode.
+
+ encoder_z (torch.Tensor): samples from the Gaussian distribution parametrized by
+ @mu and @logvar. Only required if using a GMM prior.
+
+ conditions (dict): inputs according to @self.condition_shapes. Only needs to be provided
+ if any prior parameters are input-dependent.
+
+ goal_dict (dict): inputs according to @self.goal_shapes (only if using goal observations)
+
+ Returns:
+ kl_loss (torch.Tensor): VAE KL divergence loss
+ """
+ return self.nets["prior"].kl_loss(
+ posterior_params=posterior_params,
+ z=encoder_z,
+ obs_dict=conditions,
+ goal_dict=goals,
+ )
+
+ def reconstruction_loss(self, reconstructions, targets):
+ """
+ Reconstruction loss. Note that we compute the average per-dimension error
+ in each modality and then average across all the modalities.
+
+ The beta term for weighting between reconstruction and kl losses will
+ need to be tuned in practice for each situation (see
+ https://twitter.com/memotv/status/973323454350090240 for more
+ discussion).
+
+ Args:
+ reconstructions (dict): reconstructed inputs, consistent with
+ @self.output_shapes
+ targets (dict): reconstruction targets, consistent with
+ @self.output_shapes
+
+ Returns:
+ reconstruction_loss (torch.Tensor): VAE reconstruction loss
+ """
+ random_key = list(reconstructions.keys())[0]
+ batch_size = reconstructions[random_key].shape[0]
+ num_mods = len(reconstructions.keys())
+
+ # collect errors per modality, while preserving shapes in @reconstructions
+ recons_errors = []
+ for k in reconstructions:
+ L2_loss = (reconstructions[k] - targets[k]).pow(2)
+ recons_errors.append(L2_loss)
+
+ # reduce errors across modalities and dimensions
+ if self.decoder_reconstruction_sum_across_elements:
+ # average across batch but sum across modalities and dimensions
+ loss = sum([x.sum() for x in recons_errors])
+ loss /= batch_size
+ else:
+ # compute mse loss in each modality and average across modalities
+ loss = sum([x.mean() for x in recons_errors])
+ loss /= num_mods
+ return loss
+
+ def forward(self, inputs, outputs, conditions=None, goals=None, freeze_encoder=False):
+ """
+ A full pass through the VAE network to construct KL and reconstruction
+ losses.
+
+ Args:
+ inputs (dict): a dictionary that maps input modalities to torch.Tensor
+ batches. These should correspond to the encoder-only modalities
+ (i.e. @self.encoder_only_shapes).
+
+ outputs (dict): a dictionary that maps output modalities to torch.Tensor
+ batches. These should correspond to the modalities used for
+ reconstruction (i.e. @self.output_shapes).
+
+ conditions (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to the modalities used for conditioning
+ in either the decoder or the prior (or both). Only for cVAEs.
+
+ goals (dict): a dictionary that maps modalities to torch.Tensor
+ batches. These should correspond to goal modalities. Only for cVAEs.
+
+ freeze_encoder (bool): if True, don't backprop into encoder by detaching
+ encoder outputs. Useful for doing staged VAE training.
+
+ Returns:
+ vae_outputs (dict): a dictionary that contains the following outputs.
+
+ encoder_params (dict): parameters for the posterior distribution
+ from the encoder forward pass
+
+ encoder_z (torch.Tensor): latents sampled from the encoder posterior
+
+ decoder_outputs (dict): reconstructions from the decoder
+
+ kl_loss (torch.Tensor): KL loss over the batch of data
+
+ reconstruction_loss (torch.Tensor): reconstruction loss over the batch of data
+ """
+
+ # In the comments below, X = inputs, Y = conditions, and we seek to learn P(X | Y).
+ # The decoder and prior only have knowledge about Y and try to reconstruct X.
+ # Notice that when Y is the empty set, this reduces to a normal VAE.
+
+ # mu, logvar <- Enc(X, Y)
+ posterior_params = self.encode(
+ inputs=inputs,
+ conditions=conditions,
+ goals=goals,
+ )
+
+ if freeze_encoder:
+ posterior_params = TensorUtils.detach(posterior_params)
+
+ # z ~ Enc(z | X, Y)
+ encoder_z = self.reparameterize(posterior_params)
+
+ # hat(X) = Dec(z, Y)
+ reconstructions = self.decode(
+ conditions=conditions,
+ goals=goals,
+ z=encoder_z,
+ )
+
+ # this will also train prior network z ~ Prior(z | Y)
+ kl_loss = self.kl_loss(
+ posterior_params=posterior_params,
+ encoder_z=encoder_z,
+ conditions=conditions,
+ goals=goals,
+ )
+
+ reconstruction_loss = self.reconstruction_loss(
+ reconstructions=reconstructions,
+ targets=outputs,
+ )
+
+ return {
+ "encoder_params" : posterior_params,
+ "encoder_z" : encoder_z,
+ "decoder_outputs" : reconstructions,
+ "kl_loss" : kl_loss,
+ "reconstruction_loss" : reconstruction_loss,
+ }
+
+ def set_gumbel_temperature(self, temperature):
+ """
+ Used by external algorithms to schedule Gumbel-Softmax temperature,
+ which is used during reparametrization at train-time. Should only
+ be used if @self.prior_use_categorical is True.
+ """
+ assert self.prior_use_categorical
+ self._gumbel_temperature = temperature
+
+ def get_gumbel_temperature(self):
+ """
+ Return current Gumbel-Softmax temperature. Should only be used if
+ @self.prior_use_categorical is True.
+ """
+ assert self.prior_use_categorical
+ return self._gumbel_temperature
diff --git a/phantom/submodules/phantom-robomimic/robomimic/models/value_nets.py b/phantom/submodules/phantom-robomimic/robomimic/models/value_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..c98fa7e4e0f4185b2a11e5581158268aa9cda2cf
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/models/value_nets.py
@@ -0,0 +1,318 @@
+"""
+Contains torch Modules for value networks. These networks take an
+observation dictionary as input (and possibly additional conditioning,
+such as subgoal or goal dictionaries) and produce value or
+action-value estimates or distributions.
+"""
+import numpy as np
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as D
+
+import robomimic.utils.tensor_utils as TensorUtils
+from robomimic.models.obs_nets import MIMO_MLP
+from robomimic.models.distributions import DiscreteValueDistribution
+
+
+class ValueNetwork(MIMO_MLP):
+ """
+ A basic value network that predicts values from observations.
+ Can optionally be goal conditioned on future observations.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ mlp_layer_dims,
+ value_bounds=None,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for observations.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return
+ that the network should be possible of generating. The network will rescale outputs
+ using a tanh layer to lie within these bounds. If None, no tanh re-scaling is done.
+
+ goal_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-observation key information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+ self.value_bounds = value_bounds
+ if self.value_bounds is not None:
+ # convert [lb, ub] to a scale and offset for the tanh output, which is in [-1, 1]
+ self._value_scale = (float(self.value_bounds[1]) - float(self.value_bounds[0])) / 2.
+ self._value_offset = (float(self.value_bounds[1]) + float(self.value_bounds[0])) / 2.
+
+ assert isinstance(obs_shapes, OrderedDict)
+ self.obs_shapes = obs_shapes
+
+ # set up different observation groups for @MIMO_MLP
+ observation_group_shapes = OrderedDict()
+ observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
+
+ self._is_goal_conditioned = False
+ if goal_shapes is not None and len(goal_shapes) > 0:
+ assert isinstance(goal_shapes, OrderedDict)
+ self._is_goal_conditioned = True
+ self.goal_shapes = OrderedDict(goal_shapes)
+ observation_group_shapes["goal"] = OrderedDict(self.goal_shapes)
+ else:
+ self.goal_shapes = OrderedDict()
+
+ output_shapes = self._get_output_shapes()
+ super(ValueNetwork, self).__init__(
+ input_obs_group_shapes=observation_group_shapes,
+ output_shapes=output_shapes,
+ layer_dims=mlp_layer_dims,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Allow subclasses to re-define outputs from @MIMO_MLP, since we won't
+ always directly predict values, but may instead predict the parameters
+ of a value distribution.
+ """
+ return OrderedDict(value=(1,))
+
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ return [1]
+
+ def forward(self, obs_dict, goal_dict=None):
+ """
+ Forward through value network, and then optionally use tanh scaling.
+ """
+ values = super(ValueNetwork, self).forward(obs=obs_dict, goal=goal_dict)["value"]
+ if self.value_bounds is not None:
+ values = self._value_offset + self._value_scale * torch.tanh(values)
+ return values
+
+ def _to_string(self):
+ return "value_bounds={}".format(self.value_bounds)
+
+
+class ActionValueNetwork(ValueNetwork):
+ """
+ A basic Q (action-value) network that predicts values from observations
+ and actions. Can optionally be goal conditioned on future observations.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ value_bounds=None,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return
+ that the network should be possible of generating. The network will rescale outputs
+ using a tanh layer to lie within these bounds. If None, no tanh re-scaling is done.
+
+ goal_shapes (OrderedDict): a dictionary that maps observation keys to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-observation key information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+
+ # add in action as a modality
+ new_obs_shapes = OrderedDict(obs_shapes)
+ new_obs_shapes["action"] = (ac_dim,)
+ self.ac_dim = ac_dim
+
+ # pass to super class to instantiate network
+ super(ActionValueNetwork, self).__init__(
+ obs_shapes=new_obs_shapes,
+ mlp_layer_dims=mlp_layer_dims,
+ value_bounds=value_bounds,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def forward(self, obs_dict, acts, goal_dict=None):
+ """
+ Modify forward from super class to include actions in inputs.
+ """
+ inputs = dict(obs_dict)
+ inputs["action"] = acts
+ return super(ActionValueNetwork, self).forward(inputs, goal_dict)
+
+ def _to_string(self):
+ return "action_dim={}\nvalue_bounds={}".format(self.ac_dim, self.value_bounds)
+
+
+class DistributionalActionValueNetwork(ActionValueNetwork):
+ """
+ Distributional Q (action-value) network that outputs a categorical distribution over
+ a discrete grid of value atoms. See https://arxiv.org/pdf/1707.06887.pdf for
+ more details.
+ """
+ def __init__(
+ self,
+ obs_shapes,
+ ac_dim,
+ mlp_layer_dims,
+ value_bounds,
+ num_atoms,
+ goal_shapes=None,
+ encoder_kwargs=None,
+ ):
+ """
+ Args:
+ obs_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for observations.
+
+ ac_dim (int): dimension of action space.
+
+ mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.
+
+ value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return
+ that the network should be possible of generating. This defines the support
+ of the value distribution.
+
+ num_atoms (int): number of value atoms to use for the categorical distribution - which
+ is the representation of the value distribution.
+
+ goal_shapes (OrderedDict): a dictionary that maps modality to
+ expected shapes for goal observations.
+
+ encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
+ be nested dictionary containing relevant per-modality information for encoder networks.
+ Should be of form:
+
+ obs_modality1: dict
+ feature_dimension: int
+ core_class: str
+ core_kwargs: dict
+ ...
+ ...
+ obs_randomizer_class: str
+ obs_randomizer_kwargs: dict
+ ...
+ ...
+ obs_modality2: dict
+ ...
+ """
+
+ # parameters specific to DistributionalActionValueNetwork
+ self.num_atoms = num_atoms
+ self._atoms = np.linspace(value_bounds[0], value_bounds[1], num_atoms)
+
+ # pass to super class to instantiate network
+ super(DistributionalActionValueNetwork, self).__init__(
+ obs_shapes=obs_shapes,
+ ac_dim=ac_dim,
+ mlp_layer_dims=mlp_layer_dims,
+ value_bounds=value_bounds,
+ goal_shapes=goal_shapes,
+ encoder_kwargs=encoder_kwargs,
+ )
+
+ def _get_output_shapes(self):
+ """
+ Network outputs log probabilities for categorical distribution over discrete value grid.
+ """
+ return OrderedDict(log_probs=(self.num_atoms,))
+
+ def forward_train(self, obs_dict, acts, goal_dict=None):
+ """
+ Return full critic categorical distribution.
+
+ Args:
+ obs_dict (dict): batch of observations
+ acts (torch.Tensor): batch of actions
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ value_distribution (DiscreteValueDistribution instance)
+ """
+
+ # add in actions
+ inputs = dict(obs_dict)
+ inputs["action"] = acts
+
+ # network returns unnormalized log probabilities (logits) for each of the value atoms
+ logits = MIMO_MLP.forward(self, obs=inputs, goal=goal_dict)["log_probs"]
+
+ # turn these logits into a categorical distribution over the value atoms.
+ # (unsqueeze to make sure atoms are compatible with batch operations)
+ value_atoms = torch.Tensor(self._atoms).unsqueeze(0).to(logits.device)
+ return DiscreteValueDistribution(values=value_atoms, logits=logits)
+
+ def forward(self, obs_dict, acts, goal_dict=None):
+ """
+ Return mean of critic categorical distribution. Useful for obtaining
+ point estimates of critic values.
+
+ Args:
+ obs_dict (dict): batch of observations
+ acts (torch.Tensor): batch of actions
+ goal_dict (dict): if not None, batch of goal observations
+
+ Returns:
+ mean_value (torch.Tensor): expectation of value distribution
+ """
+ vd = self.forward_train(obs_dict=obs_dict, acts=acts, goal_dict=goal_dict)
+ return vd.mean()
+
+ def _to_string(self):
+ return "action_dim={}\nvalue_bounds={}\nnum_atoms={}".format(self.ac_dim, self.value_bounds, self.num_atoms)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/bc_xfmr_gen.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/bc_xfmr_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6fe0db1e630658e6fb4a152dbb4a8035efa8a8f
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/bc_xfmr_gen.py
@@ -0,0 +1,156 @@
+from robomimic.scripts.config_gen.helper import *
+
+def make_generator_helper(args):
+ algo_name_short = "bc_xfmr"
+
+ generator = get_generator(
+ algo_name="diffusion_policy",
+ config_file=os.path.join(base_path, 'robomimic/exps/templates/diffusion_policy.json'),
+ args=args,
+ algo_name_short=algo_name_short,
+ pt=True,
+ )
+ if args.ckpt_mode is None:
+ args.ckpt_mode = "off"
+
+ generator.add_param(
+ key="train.num_data_workers",
+ name="",
+ group=-1,
+ values=[4],
+ )
+ generator.add_param(
+ key="experiment.save.every_n_epochs",
+ name="",
+ group=-1,
+ values=[
+ 100
+ ],
+ )
+
+ # run rollouts at epoch 0 only
+ generator.add_param(
+ key="experiment.rollout.warmstart",
+ name="",
+ group=-1,
+ values=[
+ -1,
+ ],
+ )
+ generator.add_param(
+ key="train.num_epochs",
+ name="",
+ group=-1,
+ values=[40],
+ )
+ generator.add_param(
+ key="experiment.rollout.rate",
+ name="",
+ group=-1,
+ values=[10],
+ )
+
+ if args.env == "r2d2":
+ generator.add_param(
+ key="train.data",
+ name="ds",
+ group=2,
+ values=[
+ # [{"path": p} for p in scan_datasets("~/code/r2d2/data/success/2023-05-23_t2c-cans", postfix="trajectory_im84.h5")],
+ [{"path": p} for p in scan_datasets("/home/cchi/local/data/r2d2/pen/success/2023-02-28", postfix="trajectory_im128.h5")],
+ ],
+ value_names=[
+ "pnp-t2c-cans-84",
+ # "pnp-t2c-cans-128",
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb.obs_randomizer_kwargs.crop_height",
+ name="",
+ group=2,
+ values=[
+ 76,
+ # 116
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb.obs_randomizer_kwargs.crop_width",
+ name="",
+ group=2,
+ values=[
+ 76,
+ # 116
+ ],
+ )
+ elif args.env == "square":
+ generator.add_param(
+ key="train.data",
+ name="ds",
+ group=2,
+ values=[
+ [
+ {"path": "~/datasets/square/ph/image_v141.hdf5"},
+ {"path": "~/datasets/square/ph/image_v141.hdf5"},
+ ],
+ ],
+ value_names=[
+ "square",
+ ],
+ )
+ else:
+ raise ValueError
+
+ if "experiment.ckpt_path" in generator.parameters:
+ generator.add_param(
+ key="algo.optim_params.policy.learning_rate.initial",
+ name="lrinit",
+ group=110,
+ values=[
+ 1e-5,
+ ],
+ hidename=True,
+ )
+ generator.add_param(
+ key="algo.optim_params.policy.learning_rate.lr_scheduler_type",
+ name="lrsch",
+ group=111,
+ values=[
+ # "linear",
+ None,
+ ],
+ value_names=[
+ "none"
+ ],
+ hidename=True,
+ )
+
+ generator.add_param(
+ key="train.output_dir",
+ name="",
+ group=-1,
+ values=[
+ "/home/cchi/dev/robomimic_r2d2/datasets/experiment_results/{env}/{mod}/{algo_name_short}".format(
+ env=args.env,
+ mod=args.mod,
+ algo_name_short=algo_name_short,
+ )
+ ],
+ )
+
+ generator.add_param(
+ key="experiment.rollout.enabled",
+ name="",
+ group=-1,
+ values=[
+ True
+ ],
+ hidename=False,
+ )
+
+ return generator
+
+if __name__ == "__main__":
+ parser = get_argparser()
+
+ args = parser.parse_args()
+ make_generator(args, make_generator_helper)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/diffusion_gen.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/diffusion_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..990a163ae72aa913fafcb87e2372ae7ca5d1fda4
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/diffusion_gen.py
@@ -0,0 +1,147 @@
+from robomimic.scripts.config_gen.helper import *
+
+def make_generator_helper(args):
+ algo_name_short = "diffusion_policy"
+
+ generator = get_generator(
+ algo_name="diffusion_policy",
+ config_file=os.path.join(base_path, 'robomimic/exps/templates/diffusion_policy.json'),
+ args=args,
+ algo_name_short=algo_name_short,
+ pt=True,
+ )
+ if args.ckpt_mode is None:
+ args.ckpt_mode = "off"
+
+ generator.add_param(
+ key="train.num_data_workers",
+ name="",
+ group=-1,
+ values=[8],
+ )
+
+ generator.add_param(
+ key="train.num_epochs",
+ name="",
+ group=-1,
+ values=[1000],
+ )
+
+ # use ddim by default
+ generator.add_param(
+ key="algo.ddim.enabled",
+ name="ddim",
+ group=1001,
+ values=[
+ True,
+ # False,
+ ],
+ )
+ generator.add_param(
+ key="algo.ddpm.enabled",
+ name="ddpm",
+ group=1001,
+ values=[
+ False,
+ # True,
+ ],
+ hidename=True,
+ )
+
+ if args.env == "r2d2":
+ generator.add_param(
+ key="train.data",
+ name="ds",
+ group=2,
+ values=[
+ [{"path": p} for p in scan_datasets("~/Downloads/example_pen_in_cup", postfix="trajectory_im128.h5")],
+ ],
+ value_names=[
+ "pen-in-cup",
+ ],
+ )
+ generator.add_param(
+ key="train.action_keys",
+ name="ac_keys",
+ group=-1,
+ values=[
+ [
+ "action/abs_pos",
+ "action/abs_rot_6d",
+ "action/gripper_velocity",
+ ],
+ ],
+ value_names=[
+ "abs",
+ ],
+ )
+ elif args.env == "square":
+ generator.add_param(
+ key="train.data",
+ name="ds",
+ group=2,
+ values=[
+ [
+ # TODO: point to the hdf5 file
+ # {"path": "/home/cchi/dev/robomimic_r2d2/datasets/square/ph/image_abs.hdf5"},
+ # {"path": "~/datasets/square/ph/image_v141.hdf5"},
+ # {"path": "~/datasets/square/ph/image.hdf5"},
+ {"path": "~/datasets/square/ph/square_ph_abs_tmp.hdf5"}, # replace with your own path
+ ],
+ ],
+ value_names=[
+ "square",
+ ],
+ )
+
+ # update env config to use absolute action control
+ generator.add_param(
+ key="experiment.env_meta_update_dict",
+ name="",
+ group=-1,
+ values=[
+ {"env_kwargs": {"controller_configs": {"control_delta": False}}}
+ ],
+ )
+
+ generator.add_param(
+ key="train.action_keys",
+ name="ac_keys",
+ group=-1,
+ values=[
+ [
+ "action_dict/abs_pos",
+ "action_dict/abs_rot_6d",
+ "action_dict/gripper",
+ # "actions",
+ ],
+ ],
+ value_names=[
+ "abs",
+ ],
+ )
+
+
+ else:
+ raise ValueError
+
+ generator.add_param(
+ key="train.output_dir",
+ name="",
+ group=-1,
+ values=[
+ "~/expdata/{env}/{mod}/{algo_name_short}".format(
+ env=args.env,
+ mod=args.mod,
+ algo_name_short=algo_name_short,
+ )
+ ],
+ )
+
+ return generator
+
+if __name__ == "__main__":
+ parser = get_argparser()
+
+ args = parser.parse_args()
+ make_generator(args, make_generator_helper)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/helper.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca29a37843c2b208f5ae7a3282bc9f41f426faa4
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/config_gen/helper.py
@@ -0,0 +1,709 @@
+import argparse
+import os
+import time
+import datetime
+
+import robomimic
+import robomimic.utils.hyperparam_utils as HyperparamUtils
+
+base_path = os.path.abspath(os.path.join(os.path.dirname(robomimic.__file__), os.pardir))
+
+def scan_datasets(folder, postfix=".h5"):
+ dataset_paths = []
+ for root, dirs, files in os.walk(os.path.expanduser(folder)):
+ for f in files:
+ if f.endswith(postfix):
+ dataset_paths.append(os.path.join(root, f))
+ return dataset_paths
+
+
+def get_generator(algo_name, config_file, args, algo_name_short=None, pt=False):
+ if args.wandb_proj_name is None:
+ strings = [
+ algo_name_short if (algo_name_short is not None) else algo_name,
+ args.name,
+ args.env,
+ args.mod,
+ ]
+ args.wandb_proj_name = '_'.join([str(s) for s in strings if s is not None])
+
+ if args.script is not None:
+ generated_config_dir = os.path.join(os.path.dirname(args.script), "json")
+ else:
+ curr_time = datetime.datetime.fromtimestamp(time.time()).strftime('%m-%d-%y-%H-%M-%S')
+ generated_config_dir=os.path.join(
+ '~/', 'tmp/autogen_configs/ril', algo_name, args.env, args.mod, args.name, curr_time, "json",
+ )
+
+ generator = HyperparamUtils.ConfigGenerator(
+ base_config_file=config_file,
+ generated_config_dir=generated_config_dir,
+ wandb_proj_name=args.wandb_proj_name,
+ script_file=args.script,
+ )
+
+ args.algo_name = algo_name
+ args.pt = pt
+
+ return generator
+
+
+def set_env_settings(generator, args):
+ if args.env in ["r2d2"]:
+ assert args.mod == "im"
+ generator.add_param(
+ key="experiment.rollout.enabled",
+ name="",
+ group=-1,
+ values=[
+ False
+ ],
+ )
+ generator.add_param(
+ key="experiment.save.every_n_epochs",
+ name="",
+ group=-1,
+ values=[50],
+ )
+ if "observation.modalities.obs.low_dim" not in generator.parameters:
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot_state/cartesian_position", "robot_state/gripper_position"]
+ ],
+ )
+ if "observation.modalities.obs.rgb" not in generator.parameters:
+ generator.add_param(
+ key="observation.modalities.obs.rgb",
+ name="",
+ group=-1,
+ values=[
+ [
+ "camera/image/hand_camera_image",
+ # "camera/image/varied_camera_1_image", "camera/image/varied_camera_2_image" # uncomment to use all 3 cameras
+ ]
+ ],
+ )
+ if "observation.encoder.rgb.obs_randomizer_kwargs.crop_height" not in generator.parameters:
+ generator.add_param(
+ key="observation.encoder.rgb.obs_randomizer_kwargs.crop_height",
+ name="",
+ group=-1,
+ values=[
+ 116
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb.obs_randomizer_kwargs.crop_width",
+ name="",
+ group=-1,
+ values=[
+ 116
+ ],
+ )
+ generator.add_param(
+ key="train.data_format",
+ name="",
+ group=-1,
+ values=[
+ "r2d2"
+ ],
+ )
+ # specify action keys in your _gen.py
+ # here, we list how each action key should be treated (normalized etc)
+ generator.add_param(
+ key="train.action_config",
+ name="",
+ group=-1,
+ values=[
+ {
+ "action/cartesian_position":{
+ "normalization": "min_max",
+ },
+ "action/abs_pos":{
+ "normalization": "min_max",
+ },
+ "action/abs_rot_6d":{
+ "normalization": "min_max",
+ "format": "rot_6d",
+ },
+ "action/abs_rot_axis_angle":{
+ "normalization": "min_max",
+ "format": "rot_axis_angle",
+ },
+ "action/gripper_position":{
+ "normalization": "min_max",
+ },
+ "action/cartesian_velocity":{
+ "normalization": None,
+ },
+ "action/gripper_velocity":{
+ "normalization": None,
+ },
+ }
+ ],
+ )
+ generator.add_param(
+ key="train.dataset_keys",
+ name="",
+ group=-1,
+ values=[[]],
+ )
+ elif args.env in ['square', 'lift', 'place_close']:
+ # # set videos off
+ # args.no_video = True
+
+ generator.add_param(
+ key="train.action_config",
+ name="",
+ group=-1,
+ values=[
+ {
+ "actions":{
+ "normalization": None,
+ },
+ "action_dict/abs_pos": {
+ "normalization": "min_max"
+ },
+ "action_dict/abs_rot_axis_angle": {
+ "normalization": "min_max",
+ "format": "rot_axis_angle"
+ },
+ "action_dict/abs_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/rel_pos": {
+ "normalization": None,
+ },
+ "action_dict/rel_rot_axis_angle": {
+ "normalization": None,
+ "format": "rot_axis_angle"
+ },
+ "action_dict/rel_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/gripper": {
+ "normalization": None,
+ }
+ }
+ ],
+ )
+
+ if args.mod == 'im':
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos"]
+ ],
+ )
+ generator.add_param(
+ key="observation.modalities.obs.rgb",
+ name="",
+ group=-1,
+ values=[
+ ["agentview_image",
+ "robot0_eye_in_hand_image"]
+ ],
+ )
+ else:
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"]
+ ],
+ )
+ elif args.env == 'transport':
+ # set videos off
+ args.no_video = True
+
+ # TODO: fix 2 robot case
+ generator.add_param(
+ key="train.action_config",
+ name="",
+ group=-1,
+ values=[
+ {
+ "actions":{
+ "normalization": None,
+ },
+ "action_dict/abs_pos": {
+ "normalization": "min_max"
+ },
+ "action_dict/abs_rot_axis_angle": {
+ "normalization": "min_max",
+ "format": "rot_axis_angle"
+ },
+ "action_dict/abs_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/rel_pos": {
+ "normalization": None,
+ },
+ "action_dict/rel_rot_axis_angle": {
+ "normalization": None,
+ "format": "rot_axis_angle"
+ },
+ "action_dict/rel_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/gripper": {
+ "normalization": None,
+ }
+ }
+ ],
+ )
+
+ if args.mod == 'im':
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "robot1_eef_pos",
+ "robot1_eef_quat",
+ "robot1_gripper_qpos"]
+ ],
+ )
+ generator.add_param(
+ key="observation.modalities.obs.rgb",
+ name="",
+ group=-1,
+ values=[
+ ["shouldercamera0_image",
+ "robot0_eye_in_hand_image",
+ "shouldercamera1_image",
+ "robot1_eye_in_hand_image"]
+ ],
+ )
+ else:
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "robot1_eef_pos",
+ "robot1_eef_quat",
+ "robot1_gripper_qpos",
+ "object"]
+ ],
+ )
+
+ generator.add_param(
+ key="experiment.rollout.horizon",
+ name="",
+ group=-1,
+ values=[700],
+ )
+ elif args.env == 'tool_hang':
+ # set videos off
+ args.no_video = True
+
+ generator.add_param(
+ key="train.action_config",
+ name="",
+ group=-1,
+ values=[
+ {
+ "actions":{
+ "normalization": None,
+ },
+ "action_dict/abs_pos": {
+ "normalization": "min_max"
+ },
+ "action_dict/abs_rot_axis_angle": {
+ "normalization": "min_max",
+ "format": "rot_axis_angle"
+ },
+ "action_dict/abs_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/rel_pos": {
+ "normalization": None,
+ },
+ "action_dict/rel_rot_axis_angle": {
+ "normalization": None,
+ "format": "rot_axis_angle"
+ },
+ "action_dict/rel_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/gripper": {
+ "normalization": None,
+ }
+ }
+ ],
+ )
+
+ if args.mod == 'im':
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos"]
+ ],
+ )
+ generator.add_param(
+ key="observation.modalities.obs.rgb",
+ name="",
+ group=-1,
+ values=[
+ ["sideview_image",
+ "robot0_eye_in_hand_image"]
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb.obs_randomizer_kwargs.crop_height",
+ name="",
+ group=-1,
+ values=[
+ 216
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb.obs_randomizer_kwargs.crop_width",
+ name="",
+ group=-1,
+ values=[
+ 216
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb2.obs_randomizer_kwargs.crop_height",
+ name="",
+ group=-1,
+ values=[
+ 216
+ ],
+ )
+ generator.add_param(
+ key="observation.encoder.rgb2.obs_randomizer_kwargs.crop_width",
+ name="",
+ group=-1,
+ values=[
+ 216
+ ],
+ )
+ else:
+ generator.add_param(
+ key="observation.modalities.obs.low_dim",
+ name="",
+ group=-1,
+ values=[
+ ["robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object"]
+ ],
+ )
+
+ generator.add_param(
+ key="experiment.rollout.horizon",
+ name="",
+ group=-1,
+ values=[700],
+ )
+ else:
+ raise ValueError
+
+
+def set_mod_settings(generator, args):
+ if args.mod == 'ld':
+ if "experiment.save.epochs" not in generator.parameters:
+ generator.add_param(
+ key="experiment.save.epochs",
+ name="",
+ group=-1,
+ values=[
+ [2000]
+ ],
+ )
+ elif args.mod == 'im':
+ if "experiment.save.every_n_epochs" not in generator.parameters:
+ generator.add_param(
+ key="experiment.save.every_n_epochs",
+ name="",
+ group=-1,
+ values=[20],
+ )
+
+ generator.add_param(
+ key="experiment.epoch_every_n_steps",
+ name="",
+ group=-1,
+ values=[500],
+ )
+ if "train.num_data_workers" not in generator.parameters:
+ generator.add_param(
+ key="train.num_data_workers",
+ name="",
+ group=-1,
+ values=[4],
+ )
+ generator.add_param(
+ key="train.hdf5_cache_mode",
+ name="",
+ group=-1,
+ values=["low_dim"],
+ )
+ if "train.batch_size" not in generator.parameters:
+ generator.add_param(
+ key="train.batch_size",
+ name="",
+ group=-1,
+ values=[16],
+ )
+ if "train.num_epochs" not in generator.parameters:
+ generator.add_param(
+ key="train.num_epochs",
+ name="",
+ group=-1,
+ values=[600],
+ )
+ if "experiment.rollout.rate" not in generator.parameters:
+ generator.add_param(
+ key="experiment.rollout.rate",
+ name="",
+ group=-1,
+ values=[20],
+ )
+
+
+def set_debug_mode(generator, args):
+ if not args.debug:
+ return
+
+ generator.add_param(
+ key="experiment.rollout.n",
+ name="",
+ group=-1,
+ values=[2],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="experiment.rollout.horizon",
+ name="",
+ group=-1,
+ values=[30],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="experiment.rollout.rate",
+ name="",
+ group=-1,
+ values=[2],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="experiment.epoch_every_n_steps",
+ name="",
+ group=-1,
+ values=[2],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="experiment.save.every_n_epochs",
+ name="",
+ group=-1,
+ values=[2],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="experiment.validation_epoch_every_n_steps",
+ name="",
+ group=-1,
+ values=[2],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="train.num_epochs",
+ name="",
+ group=-1,
+ values=[2],
+ value_names=[""],
+ )
+ if args.name is None:
+ generator.add_param(
+ key="experiment.name",
+ name="",
+ group=-1,
+ values=["debug"],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="experiment.save.enabled",
+ name="",
+ group=-1,
+ values=[False],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="train.hdf5_cache_mode",
+ name="",
+ group=-1,
+ values=["low_dim"],
+ value_names=[""],
+ )
+ generator.add_param(
+ key="train.num_data_workers",
+ name="",
+ group=-1,
+ values=[3],
+ )
+
+
+def get_argparser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--name",
+ type=str,
+ )
+
+ parser.add_argument(
+ "--env",
+ type=str,
+ default='r2d2',
+ )
+
+ parser.add_argument(
+ '--mod',
+ type=str,
+ choices=['ld', 'im'],
+ default='im',
+ )
+
+ parser.add_argument(
+ "--ckpt_mode",
+ type=str,
+ choices=["off", "all", "best_only"],
+ default=None,
+ )
+
+ parser.add_argument(
+ "--script",
+ type=str,
+ default=None
+ )
+
+ parser.add_argument(
+ "--wandb_proj_name",
+ type=str,
+ default=None
+ )
+
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ )
+
+ parser.add_argument(
+ '--no_video',
+ action='store_true'
+ )
+
+ parser.add_argument(
+ "--tmplog",
+ action="store_true",
+ )
+
+ parser.add_argument(
+ "--nr",
+ type=int,
+ default=-1
+ )
+
+ parser.add_argument(
+ "--no_wandb",
+ action="store_true",
+ )
+
+ parser.add_argument(
+ "--n_seeds",
+ type=int,
+ default=None
+ )
+
+ parser.add_argument(
+ "--num_cmd_groups",
+ type=int,
+ default=None
+ )
+
+ return parser
+
+
+def make_generator(args, make_generator_helper):
+ if args.tmplog or args.debug and args.name is None:
+ args.name = "debug"
+ else:
+ time_str = datetime.datetime.fromtimestamp(time.time()).strftime('%m-%d-')
+ args.name = time_str + str(args.name)
+
+ if args.debug or args.tmplog:
+ args.no_wandb = True
+
+ if args.wandb_proj_name is not None:
+ # prepend data to wandb name
+ # time_str = datetime.datetime.fromtimestamp(time.time()).strftime('%m-%d-')
+ # args.wandb_proj_name = time_str + args.wandb_proj_name
+ pass
+
+ if (args.debug or args.tmplog) and (args.wandb_proj_name is None):
+ args.wandb_proj_name = 'debug'
+
+ if not args.debug:
+ assert args.name is not None
+
+ # make config generator
+ generator = make_generator_helper(args)
+
+ if args.ckpt_mode is None:
+ if args.pt:
+ args.ckpt_mode = "all"
+ else:
+ args.ckpt_mode = "best_only"
+
+ set_env_settings(generator, args)
+ set_mod_settings(generator, args)
+
+ # set the debug settings last, to override previous setting changes
+ set_debug_mode(generator, args)
+
+ """ misc settings """
+ generator.add_param(
+ key="experiment.validate",
+ name="",
+ group=-1,
+ values=[
+ False,
+ ],
+ )
+
+ # generate jsons and script
+ generator.generate()
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_d4rl.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_d4rl.py
new file mode 100644
index 0000000000000000000000000000000000000000..99fc1d93c53709f6f149d43457c582cea77e853c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_d4rl.py
@@ -0,0 +1,143 @@
+"""
+Helper script to convert D4RL data into an hdf5 compatible with this repository.
+Takes a folder path and a D4RL env name. This script downloads the corresponding
+raw D4RL dataset into a "d4rl" subfolder, and then makes a converted dataset
+in the "d4rl/converted" subfolder.
+
+This script has been tested on the follwing commits:
+
+ https://github.com/rail-berkeley/d4rl/tree/9b68f31bab6a8546edfb28ff0bd9d5916c62fd1f
+ https://github.com/rail-berkeley/d4rl/tree/26adf732efafdad864b3df2287e7b778ee4f7f63
+
+Args:
+ env (str): d4rl env name, which specifies the dataset to download and convert
+ folder (str): specify folder to download raw d4rl datasets and converted d4rl datasets to.
+ A `d4rl` subfolder will be created in this folder with the raw d4rl dataset, and
+ a `d4rl/converted` subfolder will be created in this folder with the converted
+ datasets (if they do not already exist). Defaults to the datasets folder at
+ the top-level of the repository.
+
+Example usage:
+
+ # downloads to default path at robomimic/datasets/d4rl
+ python convert_d4rl.py --env walker2d-medium-expert-v2
+
+ # download to custom path
+ python convert_d4rl.py --env walker2d-medium-expert-v2 --folder /path/to/folder
+"""
+
+import os
+import h5py
+import json
+import argparse
+import numpy as np
+
+import gym
+import d4rl
+import robomimic
+from robomimic.envs.env_gym import EnvGym
+from robomimic.utils.log_utils import custom_tqdm
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--env",
+ type=str,
+ help="d4rl env name, which specifies the dataset to download and convert",
+ )
+ parser.add_argument(
+ "--folder",
+ type=str,
+ default=None,
+ help="specify folder to download raw d4rl datasets and converted d4rl datasets to.\
+ A `d4rl` subfolder will be created in this folder with the raw d4rl dataset, and\
+ a `d4rl/converted` subfolder will be created in this folder with the converted\
+ datasets (if they do not already exist). Defaults to the datasets folder at\
+ the top-level of the repository.",
+ )
+ args = parser.parse_args()
+
+ base_folder = args.folder
+ if base_folder is None:
+ base_folder = os.path.join(robomimic.__path__[0], "../datasets")
+ base_folder = os.path.join(base_folder, "d4rl")
+
+ # get dataset
+ d4rl.set_dataset_path(base_folder)
+ env = gym.make(args.env)
+ ds = env.env.get_dataset()
+ env.close()
+
+ # env
+ env = EnvGym(args.env)
+
+ # output file
+ write_folder = os.path.join(base_folder, "converted")
+ if not os.path.exists(write_folder):
+ os.makedirs(write_folder)
+ output_path = os.path.join(base_folder, "converted", "{}.hdf5".format(args.env.replace("-", "_")))
+ f_sars = h5py.File(output_path, "w")
+ f_sars_grp = f_sars.create_group("data")
+
+ # code to split D4RL data into trajectories
+ # (modified from https://github.com/aviralkumar2907/d4rl_evaluations/blob/bear_intergrate/bear/examples/bear_hdf5_d4rl.py#L18)
+ all_obs = ds['observations']
+ all_act = ds['actions']
+ N = all_obs.shape[0]
+
+ obs = all_obs[:N-1]
+ actions = all_act[:N-1]
+ next_obs = all_obs[1:]
+ rewards = np.squeeze(ds['rewards'][:N-1])
+ dones = np.squeeze(ds['terminals'][:N-1]).astype(np.int32)
+
+ assert 'timeouts' in ds
+ timeouts = ds['timeouts'][:]
+
+ ctr = 0
+ total_samples = 0
+ num_traj = 0
+ traj = dict(obs=[], next_obs=[], actions=[], rewards=[], dones=[])
+
+ print("\nConverting hdf5...")
+ for idx in custom_tqdm(range(obs.shape[0])):
+
+ # add transition
+ traj["obs"].append(obs[idx])
+ traj["actions"].append(actions[idx])
+ traj["rewards"].append(rewards[idx])
+ traj["next_obs"].append(next_obs[idx])
+ traj["dones"].append(dones[idx])
+ ctr += 1
+
+ # if hit timeout or done is True, end the current trajectory and start a new trajectory
+ if timeouts[idx] or dones[idx]:
+
+ # replace next obs with copy of current obs for final timestep, and make sure done is true
+ traj["next_obs"][-1] = np.array(obs[idx])
+ traj["dones"][-1] = 1
+
+ # store trajectory
+ ep_data_grp = f_sars_grp.create_group("demo_{}".format(num_traj))
+ ep_data_grp.create_dataset("obs/flat", data=np.array(traj["obs"]))
+ ep_data_grp.create_dataset("next_obs/flat", data=np.array(traj["next_obs"]))
+ ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
+ ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
+ ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
+ ep_data_grp.attrs["num_samples"] = len(traj["actions"])
+ total_samples += len(traj["actions"])
+ num_traj += 1
+
+ # reset
+ ctr = 0
+ traj = dict(obs=[], next_obs=[], actions=[], rewards=[], dones=[])
+
+ print("\nExcluding {} samples at end of file due to no trajectory truncation.".format(len(traj["actions"])))
+ print("Wrote {} trajectories to new converted hdf5 at {}\n".format(num_traj, output_path))
+
+ # metadata
+ f_sars_grp.attrs["total"] = total_samples
+ f_sars_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4)
+
+ f_sars.close()
+
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_r2d2.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_r2d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..016b9a9d8dbfbfaab5ddcfff2d5b52f677b9e8e6
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_r2d2.py
@@ -0,0 +1,168 @@
+"""
+Add image information to existing r2d2 hdf5 file
+"""
+import h5py
+import os
+import numpy as np
+import glob
+from tqdm import tqdm
+import argparse
+import shutil
+import torch
+import pytorch3d.transforms as pt
+
+from r2d2.camera_utils.wrappers.recorded_multi_camera_wrapper import RecordedMultiCameraWrapper
+from r2d2.trajectory_utils.trajectory_reader import TrajectoryReader
+from r2d2.camera_utils.info import camera_type_to_string_dict
+
+def convert_dataset(path, args):
+ recording_folderpath = os.path.join(os.path.dirname(path), "recordings", "MP4")
+ camera_kwargs = dict(
+ hand_camera=dict(image=True, concatenate_images=False, resolution=(args.imsize, args.imsize), resize_func="cv2"),
+ varied_camera=dict(image=True, concatenate_images=False, resolution=(args.imsize, args.imsize), resize_func="cv2"),
+ )
+ camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
+
+ output_path = os.path.join(os.path.dirname(path), "trajectory_im{}.h5".format(args.imsize))
+ if os.path.exists(output_path):
+ # dataset already exists, skip
+ f = h5py.File(output_path)
+ if "observation/camera/image/hand_camera_image" in f.keys():
+ return
+ f.close()
+
+ shutil.copyfile(path, output_path)
+ f = h5py.File(output_path, "a")
+
+ demo_len = f["action"]["cartesian_position"].shape[0]
+
+ if "camera" not in f["observation"]:
+ f["observation"].create_group("camera").create_group("image")
+ image_grp = f["observation/camera/image"]
+
+ """
+ Extract camera type and keys. Examples of what they should look like:
+ camera_type_dict = {
+ '17225336': 'hand_camera',
+ '24013089': 'varied_camera',
+ '25047636': 'varied_camera'
+ }
+ CAM_NAME_TO_KEY_MAPPING = {
+ "hand_camera_image": "17225336_left",
+ "varied_camera_left_image": "25047636_right",
+ "varied_camera_right_image": "24013089_left"
+ }
+ """
+
+ CAM_ID_TO_TYPE = {}
+ for k in f["observation"]["camera_type"]:
+ CAM_ID_TO_TYPE[k] = camera_type_to_string_dict[f["observation"]["camera_type"][k][0]]
+
+ CAM_NAME_TO_KEY_MAPPING = {}
+ for (cam_id, cam_type) in CAM_ID_TO_TYPE.items():
+ if cam_type == "hand_camera":
+ cam_name = "hand_camera_image"
+ cam_key = "{}_left".format(cam_id)
+ elif cam_type == "varied_camera":
+ cam_name = "varied_camera_1_image" if "varied_camera_1_image" not in CAM_NAME_TO_KEY_MAPPING else "varied_camera_2_image"
+ cam_key = "{}_left".format(cam_id)
+ else:
+ raise NotImplementedError
+
+ CAM_NAME_TO_KEY_MAPPING[cam_name] = cam_key
+
+ cam_data = {cam_name: [] for cam_name in CAM_NAME_TO_KEY_MAPPING.keys()}
+ traj_reader = TrajectoryReader(path, read_images=False)
+
+ for index in range(demo_len):
+
+ timestep = traj_reader.read_timestep(index=index)
+ timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
+
+ timestamp_dict = {}
+ camera_obs = camera_reader.read_cameras(
+ index=index, camera_type_dict=CAM_ID_TO_TYPE, timestamp_dict=timestamp_dict
+ )
+ for cam_name in CAM_NAME_TO_KEY_MAPPING.keys():
+ if camera_obs is None:
+ im = np.zeros((args.imsize, args.imsize, 3))
+ else:
+ im_key = CAM_NAME_TO_KEY_MAPPING[cam_name]
+ im = camera_obs["image"][im_key]
+
+ # perform bgr_to_rgb operation
+ im = im[:,:,::-1]
+
+ cam_data[cam_name].append(im)
+
+ for cam_name in cam_data.keys():
+ cam_data[cam_name] = np.array(cam_data[cam_name]).astype(np.uint8)
+ if cam_name in image_grp:
+ del image_grp[cam_name]
+ image_grp.create_dataset(cam_name, data=cam_data[cam_name], compression="gzip")
+
+ # extract action key data
+ action_dict_group = f["action"]
+ for in_ac_key in ["cartesian_position", "cartesian_velocity"]:
+ in_action = action_dict_group[in_ac_key][:]
+ in_pos = in_action[:,:3].astype(np.float64)
+ in_rot = in_action[:,3:6].astype(np.float64)
+ rot_ = torch.from_numpy(in_rot)
+ rot_mat = pt.axis_angle_to_matrix(rot_)
+ rot_6d = pt.matrix_to_rotation_6d(rot_mat).numpy().astype(np.float64)
+
+ if in_ac_key == "cartesian_position":
+ prefix = "abs_"
+ elif in_ac_key == "cartesian_velocity":
+ prefix = "rel_"
+ else:
+ raise ValueError
+
+ this_action_dict = {
+ prefix + 'pos': in_pos,
+ prefix + 'rot_axis_angle': in_rot,
+ prefix + 'rot_6d': rot_6d,
+ }
+ for key, data in this_action_dict.items():
+ if key in action_dict_group:
+ del action_dict_group[key]
+ action_dict_group.create_dataset(key, data=data)
+
+ # ensure all action keys are batched (ie., are not 0-dimensional)
+ for k in action_dict_group:
+ if isinstance(action_dict_group[k], h5py.Dataset) and len(action_dict_group[k].shape) == 1:
+ reshaped_values = np.reshape(action_dict_group[k][:], (-1, 1))
+ del action_dict_group[k]
+ action_dict_group.create_dataset(k, data=reshaped_values)
+
+ f.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--folder",
+ type=str,
+ help="folder containing hdf5's to add camera images to",
+ default="~/datasets/r2d2/success"
+ )
+
+ parser.add_argument(
+ "--imsize",
+ type=int,
+ default=128,
+ help="image size (w and h)",
+ )
+
+ args = parser.parse_args()
+
+ datasets = []
+ for root, dirs, files in os.walk(os.path.expanduser(args.folder)):
+ for f in files:
+ if f == "trajectory.h5":
+ datasets.append(os.path.join(root, f))
+
+ print("converting datasets...")
+ for d in tqdm(datasets):
+ d = os.path.expanduser(d)
+ convert_dataset(d, args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_robosuite.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_robosuite.py
new file mode 100644
index 0000000000000000000000000000000000000000..8825869824c92791f7bf51058e1e5045ca40b72e
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_robosuite.py
@@ -0,0 +1,75 @@
+"""
+Helper script to convert a dataset collected using robosuite into an hdf5 compatible with
+this repository. Takes a dataset path corresponding to the demo.hdf5 file containing the
+demonstrations. It modifies the dataset in-place. By default, the script also creates a
+90-10 train-validation split.
+
+For more information on collecting datasets with robosuite, see the code link and documentation
+link below.
+
+Code: https://github.com/ARISE-Initiative/robosuite/blob/offline_study/robosuite/scripts/collect_human_demonstrations.py
+
+Documentation: https://robosuite.ai/docs/algorithms/demonstrations.html
+
+Example usage:
+
+ python convert_robosuite.py --dataset /path/to/your/demo.hdf5
+"""
+
+import h5py
+import json
+import argparse
+
+import robomimic.envs.env_base as EB
+from robomimic.scripts.split_train_val import split_train_val_from_hdf5
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="path to input hdf5 dataset",
+ )
+ args = parser.parse_args()
+
+ f = h5py.File(args.dataset, "a") # edit mode
+
+ # store env meta
+ env_name = f["data"].attrs["env"]
+ env_info = json.loads(f["data"].attrs["env_info"])
+ env_meta = dict(
+ type=EB.EnvType.ROBOSUITE_TYPE,
+ env_name=env_name,
+ env_version=f["data"].attrs["repository_version"],
+ env_kwargs=env_info,
+ )
+ if "env_args" in f["data"].attrs:
+ del f["data"].attrs["env_args"]
+ f["data"].attrs["env_args"] = json.dumps(env_meta, indent=4)
+
+ print("====== Stored env meta ======")
+ print(f["data"].attrs["env_args"])
+
+ # store metadata about number of samples
+ total_samples = 0
+ for ep in f["data"]:
+ # ensure model-xml is in per-episode metadata
+ assert "model_file" in f["data/{}".format(ep)].attrs
+
+ # add "num_samples" into per-episode metadata
+ if "num_samples" in f["data/{}".format(ep)].attrs:
+ del f["data/{}".format(ep)].attrs["num_samples"]
+ n_sample = f["data/{}/actions".format(ep)].shape[0]
+ f["data/{}".format(ep)].attrs["num_samples"] = n_sample
+ total_samples += n_sample
+
+ # add total samples to global metadata
+ if "total" in f["data"].attrs:
+ del f["data"].attrs["total"]
+ f["data"].attrs["total"] = total_samples
+
+ f.close()
+
+ # create 90-10 train-validation split in the dataset
+ split_train_val_from_hdf5(hdf5_path=args.dataset, val_ratio=0.1)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_roboturk_pilot.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_roboturk_pilot.py
new file mode 100644
index 0000000000000000000000000000000000000000..2105980453d59be953f3fbe58a3a5ace12a8dccb
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_roboturk_pilot.py
@@ -0,0 +1,192 @@
+"""
+Helper script to convert the RoboTurk Pilot datasets (https://roboturk.stanford.edu/dataset_sim.html)
+into a format compatible with this repository. It will also create some useful filter keys
+in the file (e.g. training, validation, and fastest n trajectories). Prior work
+(https://arxiv.org/abs/1911.05321) has found this useful (for example, training on the
+fastest 225 demonstrations for bins-Can).
+
+Direct download link for dataset: http://cvgl.stanford.edu/projects/roboturk/RoboTurkPilot.zip
+
+Args:
+ folder (str): path to a folder containing a demo.hdf5 and a models directory containing
+ mujoco xml files. For example, RoboTurkPilot/bins-Can.
+
+ n (int): creates a filter key corresponding to the n fastest trajectories. Defaults to 225.
+
+Example usage:
+
+ python convert_roboturk_pilot.py --folder /path/to/RoboTurkPilot/bins-Can --n 225
+"""
+
+import os
+import h5py
+import json
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+import robomimic
+import robomimic.envs.env_base as EB
+from robomimic.utils.file_utils import create_hdf5_filter_key
+from robomimic.scripts.split_train_val import split_train_val_from_hdf5
+
+
+def convert_rt_pilot_hdf5(ref_folder):
+ """
+ Uses the reference demo hdf5 to write a new converted hdf5 compatible with
+ the repository.
+
+ Args:
+ ref_folder (str): path to a folder containing a demo.hdf5 and a models directory containing
+ mujoco xml files.
+ """
+ hdf5_path = os.path.join(ref_folder, "demo.hdf5")
+ new_path = os.path.join(ref_folder, "demo_new.hdf5")
+
+ f = h5py.File(hdf5_path, "r")
+ f_new = h5py.File(new_path, "w")
+ f_new_grp = f_new.create_group("data")
+
+ # sorted list of demonstrations by demo number
+ demos = list(f["data"].keys())
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+
+ # write each demo
+ num_samples_arr = []
+ for demo_id in tqdm(range(len(demos))):
+ ep = demos[demo_id]
+
+ # create group for this demonstration
+ ep_data_grp = f_new_grp.create_group(ep)
+
+ # copy states over
+ states = f["data/{}/states".format(ep)][()]
+ ep_data_grp.create_dataset("states", data=np.array(states))
+
+ # concat jvels and gripper actions to form full actions
+ jvels = f["data/{}/joint_velocities".format(ep)][()]
+ gripper_acts = f["data/{}/gripper_actuations".format(ep)][()]
+ actions = np.concatenate([jvels, gripper_acts], axis=1)
+
+ # IMPORTANT: clip actions to -1, 1, since this is expected by the codebase
+ actions = np.clip(actions, -1., 1.)
+ ep_data_grp.create_dataset("actions", data=actions)
+
+ # store model xml directly in the new hdf5 file
+ model_path = os.path.join(ref_folder, "models", f["data/{}".format(ep)].attrs["model_file"])
+ f_model = open(model_path, "r")
+ model_xml = f_model.read()
+ f_model.close()
+ ep_data_grp.attrs["model_file"] = model_xml
+
+ # store num samples for this ep
+ num_samples = actions.shape[0]
+ ep_data_grp.attrs["num_samples"] = num_samples # number of transitions in this episode
+ num_samples_arr.append(num_samples)
+
+ # write dataset attributes (metadata)
+ f_new_grp.attrs["total"] = np.sum(num_samples_arr)
+
+ # construct and save env metadata
+ env_meta = dict()
+ env_meta["type"] = EB.EnvType.ROBOSUITE_TYPE
+ env_meta["env_name"] = (f["data"].attrs["env"] + "Teleop")
+ # hardcode robosuite v0.3 args
+ robosuite_args = {
+ "has_renderer": False,
+ "has_offscreen_renderer": False,
+ "ignore_done": True,
+ "use_object_obs": True,
+ "use_camera_obs": False,
+ "camera_depth": False,
+ "camera_height": 84,
+ "camera_width": 84,
+ "camera_name": "agentview",
+ "gripper_visualization": False,
+ "reward_shaping": False,
+ "control_freq": 100,
+ }
+ env_meta["env_kwargs"] = robosuite_args
+ f_new_grp.attrs["env_args"] = json.dumps(env_meta, indent=4) # environment info
+
+ print("\n====== Added env meta ======")
+ print(f_new_grp.attrs["env_args"])
+
+ f.close()
+ f_new.close()
+
+ # back up the old dataset, and replace with new dataset
+ os.rename(hdf5_path, os.path.join(ref_folder, "demo_bak.hdf5"))
+ os.rename(new_path, hdf5_path)
+
+
+def split_fastest_from_hdf5(hdf5_path, n):
+ """
+ Creates filter key for fastest N trajectories, named
+ "fastest_{}".format(n).
+
+ Args:
+ hdf5_path (str): path to the hdf5 file
+
+ n (int): fastest n demos to create filter key for
+ """
+
+ # retrieve fastest n demos
+ f = h5py.File(hdf5_path, "r")
+ demos = sorted(list(f["data"].keys()))
+ traj_lengths = []
+ for ep in demos:
+ traj_lengths.append(f["data/{}/actions".format(ep)].shape[0])
+ inds = np.argsort(traj_lengths)[:n]
+ filtered_demos = [demos[i] for i in inds]
+ f.close()
+
+ # create filter key
+ name = "fastest_{}".format(n)
+ lengths = create_hdf5_filter_key(hdf5_path=hdf5_path, demo_keys=filtered_demos, key_name=name)
+
+ print("Total number of samples in fastest {} demos: {}".format(n, np.sum(lengths)))
+ print("Average number of samples in fastest {} demos: {}".format(n, np.mean(lengths)))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--folder",
+ type=str,
+ help="path to a folder containing a demo.hdf5 and a models directory containing \
+ mujoco xml files. For example, RoboTurkPilot/bins-Can.",
+ )
+ parser.add_argument(
+ "--n",
+ type=int,
+ default=225,
+ help="creates a filter key corresponding to the n fastest trajectories. Defaults to 225.",
+ )
+ args = parser.parse_args()
+
+ # convert hdf5
+ convert_rt_pilot_hdf5(ref_folder=args.folder)
+
+ # create 90-10 train-validation split in the dataset
+ print("\nCreating 90-10 train-validation split...\n")
+ hdf5_path = os.path.join(args.folder, "demo.hdf5")
+ split_train_val_from_hdf5(hdf5_path=hdf5_path, val_ratio=0.1)
+
+ print("\nCreating filter key for fastest {} trajectories...".format(args.n))
+ split_fastest_from_hdf5(hdf5_path=hdf5_path, n=args.n)
+
+ print("\nCreating 90-10 train-validation split for fastest {} trajectories...".format(args.n))
+ split_train_val_from_hdf5(hdf5_path=hdf5_path, val_ratio=0.1, filter_key="fastest_{}".format(args.n))
+
+ print(
+ "\nWARNING: new dataset has replaced old one in demo.hdf5 file. "
+ "The old dataset file has been moved to demo_bak.hdf5"
+ )
+
+ print(
+ "\nNOTE: the new dataset also contains a fastest_{} filter key, for an easy way "
+ "to train on the fastest trajectories. Just set config.train.hdf5_filter to train on this "
+ "subset. A common choice is 225 when training on the bins-Can dataset.\n".format(args.n)
+ )
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_to_robosuite_v141.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_to_robosuite_v141.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf694c7abf17d8b28b3ff78bf99bf710d22072b
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/convert_to_robosuite_v141.py
@@ -0,0 +1,156 @@
+import h5py
+import json
+import argparse
+import os
+from shutil import copyfile
+import robosuite
+import xml.etree.ElementTree as ET
+
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.file_utils as FileUtils
+
+from robosuite.utils.mjcf_utils import find_elements
+
+def replace_elem(parent, old_elem, new_elem):
+ """
+ code adapted from https://stackoverflow.com/a/20931505
+ """
+ parent_index = list(parent).index(old_elem)
+ parent.remove(old_elem)
+ parent.insert(parent_index, new_elem)
+
+def convert_xml(old_xml_str, env_name, env):
+ """
+ Postprocess xml string generated by robosuite to be compatible with robosuite v1.3
+ This script should not the xml string if it was already generated using robosuite v1.3
+ Args:
+ xml_str (str): xml string to process (from robosuite v1.2)
+ """
+
+ if env_name in ["PickPlaceCan", "NutAssemblySquare", "ToolHang"]:
+ xml_str = env.env.sim.model.get_xml()
+ elif env_name == "Lift":
+ xml_str = env.env.sim.model.get_xml()
+ # replace the cube_g0 and cube_g0_vis with elements in old_xml_str
+ old_et = ET.ElementTree(ET.fromstring(old_xml_str)).getroot()
+ new_et = ET.ElementTree(ET.fromstring(xml_str)).getroot()
+
+ cube_new = find_elements(
+ root=new_et,
+ tags="body",
+ attribs={"name": "cube_main"},
+ return_first=True
+ )
+
+ cube_old = find_elements(
+ root=old_et,
+ tags="body",
+ attribs={"name": "cube_main"},
+ return_first=True
+ )
+
+ worldbody_new = find_elements(
+ root=new_et,
+ tags="worldbody",
+ return_first=True
+ )
+
+ replace_elem(worldbody_new, cube_new, cube_old)
+
+ xml_str = ET.tostring(new_et, encoding="utf8").decode("utf8")
+ elif env_name == "TwoArmTransport":
+ xml_str = env.env.sim.model.get_xml()
+ # replace the cube_g0 and cube_g0_vis with elements in old_xml_str
+ old_et = ET.ElementTree(ET.fromstring(old_xml_str)).getroot()
+ new_et = ET.ElementTree(ET.fromstring(xml_str)).getroot()
+
+ worldbody_new = find_elements(
+ root=new_et,
+ tags="worldbody",
+ return_first=True
+ )
+ for bname in [
+ "payload_root",
+
+ ### ignore all these other following assets (makes playback worse for some reason...)
+ # "trash_main",
+ # "transport_start_bin_root", "transport_target_bin_root",
+ # "transport_trash_bin_root", "transport_start_bin_lid_root"
+ ]:
+ body_new = find_elements(
+ root=new_et,
+ tags="body",
+ attribs={"name": bname},
+ return_first=True
+ )
+
+ body_old = find_elements(
+ root=old_et,
+ tags="body",
+ attribs={"name": bname},
+ return_first=True
+ )
+
+ replace_elem(worldbody_new, body_new, body_old)
+
+ xml_str = ET.tostring(new_et, encoding="utf8").decode("utf8")
+
+ return xml_str
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="path to input hdf5 dataset",
+ )
+ parser.add_argument(
+ "--output_dataset",
+ type=str,
+ help="path to output hdf5 dataset",
+ )
+ args = parser.parse_args()
+
+ args.dataset = os.path.expanduser(args.dataset)
+ args.output_dataset = os.path.expanduser(args.output_dataset)
+
+ assert args.output_dataset != args.dataset
+ assert robosuite.__version__ == '1.4.1'
+
+ copyfile(args.dataset, args.output_dataset)
+
+ f = h5py.File(args.output_dataset, "r+")
+
+ env_args = json.loads(f["data"].attrs["env_args"])
+ env_name = env_args["env_name"]
+
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
+ env_type = EnvUtils.get_env_type(env_meta=env_meta)
+
+ # need to make sure ObsUtils knows which observations are images, but it doesn't matter
+ # for playback since observations are unused. Pass a dummy spec here.
+ dummy_spec = dict(
+ obs=dict(
+ low_dim=["robot0_eef_pos"],
+ rgb=[],
+ ),
+ )
+ ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
+
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
+ env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True)
+ env.reset()
+
+ for demo_key in list(f["data"].keys()):
+ ep_data_grp = f["data/{}".format(demo_key)]
+ model_file = ep_data_grp.attrs["model_file"]
+
+ coverted_model_file = convert_xml(model_file, env_name, env)
+ ep_data_grp.attrs["model_file"] = coverted_model_file
+
+ env_args = json.loads(f["data"].attrs["env_args"])
+ env_args["env_version"] = robosuite.__version__
+ f["data"].attrs["env_args"] = json.dumps(env_args, indent=4)
+
+ f.close()
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/extract_action_dict.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/extract_action_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..798f59263e03c32fb806b41f7a3a07aa18152631
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/extract_action_dict.py
@@ -0,0 +1,71 @@
+import argparse
+import pathlib
+import sys
+import tqdm
+import h5py
+import numpy as np
+import torch
+import os
+
+def extract_action_dict(args):
+ # find files
+ f = h5py.File(os.path.expanduser(args.dataset), mode="r+")
+
+ SPECS = [
+ dict(
+ key="actions",
+ is_absolute=False,
+ ),
+ dict(
+ key="actions_abs",
+ is_absolute=True,
+ )
+ ]
+
+ # execute
+ for spec in SPECS:
+ input_action_key = spec["key"]
+ is_absolute = spec["is_absolute"]
+
+ if is_absolute:
+ prefix = "abs_"
+ else:
+ prefix = "rel_"
+
+ for demo in f['data'].values():
+ in_action = demo[str(input_action_key)][:]
+ in_pos = in_action[:,:3].astype(np.float32)
+ in_rot = in_action[:,3:6].astype(np.float32)
+ in_grip = in_action[:,6:].astype(np.float32)
+
+ rot_ = torch.from_numpy(in_rot)
+ rot_6d = TorchUtils.axis_angle_to_rot_6d(rot_).numpy().astype(np.float32)
+
+ this_action_dict = {
+ prefix + 'pos': in_pos,
+ prefix + 'rot_axis_angle': in_rot,
+ prefix + 'rot_6d': rot_6d,
+ 'gripper': in_grip
+ }
+ # if 'action_dict' in demo:
+ # del demo['action_dict']
+ action_dict_group = demo.require_group('action_dict')
+ for key, data in this_action_dict.items():
+ if key in action_dict_group:
+ del action_dict_group[key]
+ action_dict_group.create_dataset(key, data=data)
+
+ f.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ required=True
+ )
+
+ args = parser.parse_args()
+
+ extract_action_dict(args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/robosuite_add_absolute_actions.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/robosuite_add_absolute_actions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d1565d00b97d513ea1ea41c5fb3bedf85923560
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/robosuite_add_absolute_actions.py
@@ -0,0 +1,290 @@
+if __name__ == "__main__":
+ import sys
+ import os
+ import pathlib
+
+ ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
+ sys.path.append(ROOT_DIR)
+
+import multiprocessing
+import os
+import shutil
+import click
+import pathlib
+import h5py
+from tqdm import tqdm
+import collections
+import pickle
+
+
+
+import numpy as np
+import copy
+
+import h5py
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.env_utils as EnvUtils
+from scipy.spatial.transform import Rotation
+
+from robomimic.config import config_factory
+
+"""
+copied/adapted from https://github.com/columbia-ai-robotics/diffusion_policy/blob/main/diffusion_policy/common/robomimic_util.py
+"""
+class RobomimicAbsoluteActionConverter:
+ def __init__(self, dataset_path, algo_name='bc'):
+ # default BC config
+ config = config_factory(algo_name=algo_name)
+
+ # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
+ # must ran before create dataset
+ ObsUtils.initialize_obs_utils_with_config(config)
+
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
+ abs_env_meta = copy.deepcopy(env_meta)
+ abs_env_meta['env_kwargs']['controller_configs']['control_delta'] = False
+
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ )
+ assert len(env.env.robots) in (1, 2)
+
+ abs_env = EnvUtils.create_env_from_metadata(
+ env_meta=abs_env_meta,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ )
+ assert not abs_env.env.robots[0].controller.use_delta
+
+ self.env = env
+ self.abs_env = abs_env
+ self.file = h5py.File(dataset_path, 'r')
+
+ def __len__(self):
+ return len(self.file['data'])
+
+ def convert_actions(self,
+ states: np.ndarray,
+ actions: np.ndarray) -> np.ndarray:
+ """
+ Given state and delta action sequence
+ generate equivalent goal position and orientation for each step
+ keep the original gripper action intact.
+ """
+ # in case of multi robot
+ # reshape (N,14) to (N,2,7)
+ # or (N,7) to (N,1,7)
+ stacked_actions = actions.reshape(*actions.shape[:-1],-1,7)
+
+ env = self.env
+ # generate abs actions
+ action_goal_pos = np.zeros(
+ stacked_actions.shape[:-1]+(3,),
+ dtype=stacked_actions.dtype)
+ action_goal_ori = np.zeros(
+ stacked_actions.shape[:-1]+(3,),
+ dtype=stacked_actions.dtype)
+ action_gripper = stacked_actions[...,[-1]]
+ for i in range(len(states)):
+ _ = env.reset_to({'states': states[i]})
+
+ # taken from robot_env.py L#454
+ for idx, robot in enumerate(env.env.robots):
+ # run controller goal generator
+ robot.control(stacked_actions[i,idx], policy_step=True)
+
+ # read pos and ori from robots
+ controller = robot.controller
+ action_goal_pos[i,idx] = controller.goal_pos
+ action_goal_ori[i,idx] = Rotation.from_matrix(
+ controller.goal_ori).as_rotvec()
+
+ stacked_abs_actions = np.concatenate([
+ action_goal_pos,
+ action_goal_ori,
+ action_gripper
+ ], axis=-1)
+ abs_actions = stacked_abs_actions.reshape(actions.shape)
+ return abs_actions
+
+ def convert_idx(self, idx):
+ file = self.file
+ demo = file[f'data/demo_{idx}']
+ # input
+ states = demo['states'][:]
+ actions = demo['actions'][:]
+
+ # generate abs actions
+ abs_actions = self.convert_actions(states, actions)
+ return abs_actions
+
+ def convert_and_eval_idx(self, idx):
+ env = self.env
+ abs_env = self.abs_env
+ file = self.file
+ # first step have high error for some reason, not representative
+ eval_skip_steps = 1
+
+ demo = file[f'data/demo_{idx}']
+ # input
+ states = demo['states'][:]
+ actions = demo['actions'][:]
+
+ # generate abs actions
+ abs_actions = self.convert_actions(states, actions)
+
+ # verify
+ robot0_eef_pos = demo['obs']['robot0_eef_pos'][:]
+ robot0_eef_quat = demo['obs']['robot0_eef_quat'][:]
+
+ delta_error_info = self.evaluate_rollout_error(
+ env, states, actions, robot0_eef_pos, robot0_eef_quat,
+ metric_skip_steps=eval_skip_steps)
+ abs_error_info = self.evaluate_rollout_error(
+ abs_env, states, abs_actions, robot0_eef_pos, robot0_eef_quat,
+ metric_skip_steps=eval_skip_steps)
+
+ info = {
+ 'delta_max_error': delta_error_info,
+ 'abs_max_error': abs_error_info
+ }
+ return abs_actions, info
+
+ @staticmethod
+ def evaluate_rollout_error(env,
+ states, actions,
+ robot0_eef_pos,
+ robot0_eef_quat,
+ metric_skip_steps=1):
+ # first step have high error for some reason, not representative
+
+ # evaluate abs actions
+ rollout_next_states = list()
+ rollout_next_eef_pos = list()
+ rollout_next_eef_quat = list()
+ obs = env.reset_to({'states': states[0]})
+ for i in range(len(states)):
+ obs = env.reset_to({'states': states[i]})
+ obs, reward, done, info = env.step(actions[i])
+ obs = env.get_observation()
+ rollout_next_states.append(env.get_state()['states'])
+ rollout_next_eef_pos.append(obs['robot0_eef_pos'])
+ rollout_next_eef_quat.append(obs['robot0_eef_quat'])
+ rollout_next_states = np.array(rollout_next_states)
+ rollout_next_eef_pos = np.array(rollout_next_eef_pos)
+ rollout_next_eef_quat = np.array(rollout_next_eef_quat)
+
+ next_state_diff = states[1:] - rollout_next_states[:-1]
+ max_next_state_diff = np.max(np.abs(next_state_diff[metric_skip_steps:]))
+
+ next_eef_pos_diff = robot0_eef_pos[1:] - rollout_next_eef_pos[:-1]
+ next_eef_pos_dist = np.linalg.norm(next_eef_pos_diff, axis=-1)
+ max_next_eef_pos_dist = next_eef_pos_dist[metric_skip_steps:].max()
+
+ next_eef_rot_diff = Rotation.from_quat(robot0_eef_quat[1:]) \
+ * Rotation.from_quat(rollout_next_eef_quat[:-1]).inv()
+ next_eef_rot_dist = next_eef_rot_diff.magnitude()
+ max_next_eef_rot_dist = next_eef_rot_dist[metric_skip_steps:].max()
+
+ info = {
+ 'state': max_next_state_diff,
+ 'pos': max_next_eef_pos_dist,
+ 'rot': max_next_eef_rot_dist
+ }
+ return info
+
+"""
+copied/adapted from https://github.com/columbia-ai-robotics/diffusion_policy/blob/main/diffusion_policy/scripts/robomimic_dataset_conversion.py
+"""
+def worker(x):
+ path, idx, do_eval = x
+ converter = RobomimicAbsoluteActionConverter(path)
+ if do_eval:
+ abs_actions, info = converter.convert_and_eval_idx(idx)
+ else:
+ abs_actions = converter.convert_idx(idx)
+ info = dict()
+ return abs_actions, info
+
+@click.command()
+@click.option('-i', '--input', required=True, help='input hdf5 path')
+@click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
+@click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics')
+@click.option('-n', '--num_workers', default=None, type=int)
+def main(input, output, eval_dir, num_workers):
+ # process inputs
+ input = pathlib.Path(input).expanduser()
+ assert input.is_file()
+ output = pathlib.Path(output).expanduser()
+ assert output.parent.is_dir()
+ assert not output.is_dir()
+
+ do_eval = False
+ if eval_dir is not None:
+ eval_dir = pathlib.Path(eval_dir).expanduser()
+ assert eval_dir.parent.exists()
+ do_eval = True
+
+ converter = RobomimicAbsoluteActionConverter(input)
+
+ # run
+ with multiprocessing.Pool(num_workers) as pool:
+ results = pool.map(worker, [(input, i, do_eval) for i in range(len(converter))])
+
+ # save output
+ print('Copying hdf5')
+ shutil.copy(str(input), str(output))
+
+ # modify action
+ with h5py.File(output, 'r+') as out_file:
+ for i in tqdm(range(len(converter)), desc="Writing to output"):
+ abs_actions, info = results[i]
+ demo = out_file[f'data/demo_{i}']
+ if "actions_abs" not in demo:
+ demo.create_dataset("actions_abs", data=np.array(abs_actions))
+ else:
+ demo['actions_abs'][:] = abs_actions
+
+ # save eval
+ if do_eval:
+ eval_dir.mkdir(parents=False, exist_ok=True)
+
+ print("Writing error_stats.pkl")
+ infos = [info for _, info in results]
+ pickle.dump(infos, eval_dir.joinpath('error_stats.pkl').open('wb'))
+
+ print("Generating visualization")
+ metrics = ['pos', 'rot']
+ metrics_dicts = dict()
+ for m in metrics:
+ metrics_dicts[m] = collections.defaultdict(list)
+
+ for i in range(len(infos)):
+ info = infos[i]
+ for k, v in info.items():
+ for m in metrics:
+ metrics_dicts[m][k].append(v[m])
+
+ from matplotlib import pyplot as plt
+ plt.switch_backend('PDF')
+
+ fig, ax = plt.subplots(1, len(metrics))
+ for i in range(len(metrics)):
+ axis = ax[i]
+ data = metrics_dicts[metrics[i]]
+ for key, value in data.items():
+ axis.plot(value, label=key)
+ axis.legend()
+ axis.set_title(metrics[i])
+ fig.set_size_inches(10,4)
+ fig.savefig(str(eval_dir.joinpath('error_stats.pdf')))
+ fig.savefig(str(eval_dir.joinpath('error_stats.png')))
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/set_dataset_attr.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/set_dataset_attr.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f148d08f0bc9ef6336c919ee7cf4e639aded8b5
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/conversion/set_dataset_attr.py
@@ -0,0 +1,98 @@
+"""
+Example:
+python robomimic/scripts/set_dataset_attr.py --glob 'datasets/**/*_abs.hdf5' --env_args env_kwargs.controller_configs.control_delta=false absolute_actions=true
+"""
+import argparse
+import pathlib
+import json
+import sys
+import tqdm
+import h5py
+
+def update_env_args_dict(env_args_dict: dict, key: tuple, value):
+ if key is None:
+ return env_args_dict
+ elif len(key) == 0:
+ return env_args_dict
+ elif len(key) == 1:
+ env_args_dict[key[0]] = value
+ return env_args_dict
+ else:
+ this_key = key[0]
+ if this_key not in env_args_dict:
+ env_args_dict[this_key] = dict()
+ update_env_args_dict(env_args_dict[this_key], key[1:], value)
+ return env_args_dict
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--glob",
+ type=str,
+ required=True
+ )
+
+ parser.add_argument(
+ "--env_args",
+ type=str,
+ default=None
+ )
+
+ parser.add_argument(
+ 'attrs',
+ nargs='*'
+ )
+
+ args = parser.parse_args()
+
+ # parse attrs to set
+ # format: key=value
+ # values are parsed with json
+ attrs_dict = dict()
+ for attr_arg in args.attrs:
+ key, svalue = attr_arg.split("=")
+ value = json.loads(svalue)
+ attrs_dict[key] = value
+
+ # parse env_args update
+ env_args_key = None
+ env_args_value = None
+ if args.env_args is not None:
+ key, svalue = args.env_args.split('=')
+ env_args_key = key.split('.')
+ env_args_value = json.loads(svalue)
+
+ # find files
+ file_paths = list(pathlib.Path.cwd().glob(args.glob))
+
+ # confirm with the user
+ print("Found matching files:")
+ for f in file_paths:
+ print(f)
+ print("Are you sure to modify these files with the following attributes:")
+ print(json.dumps(attrs_dict, indent=2))
+ if env_args_key is not None:
+ print("env_args."+'.'.join(env_args_key)+'='+str(env_args_value))
+ result = input("[y/n]?")
+ if 'y' not in result:
+ sys.exit(0)
+
+ # execute
+ for file_path in tqdm.tqdm(file_paths):
+ with h5py.File(str(file_path), mode='r+') as file:
+ # update env_args
+ if env_args_key is not None:
+ env_args = file['data'].attrs['env_args']
+ env_args_dict = json.loads(env_args)
+ env_args_dict = update_env_args_dict(
+ env_args_dict=env_args_dict,
+ key=env_args_key, value=env_args_value)
+ env_args = json.dumps(env_args_dict)
+ file['data'].attrs['env_args'] = env_args
+
+ # update other attrs
+ file['data'].attrs.update(attrs_dict)
+
+if __name__ == "__main__":
+ main()
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/convert_actions.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/convert_actions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0048ac44a47af194e641e2e1220b66030ef869b
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/convert_actions.py
@@ -0,0 +1,89 @@
+"""
+Helper script to prepare datasets for diffusion policy training by (1) adding absolute actions and (2)
+writing the absolute actions to action dictionaries.
+"""
+import os
+import h5py
+import argparse
+import socket
+import json
+import numpy as np
+
+import robomimic
+import robomimic.macros as Macros
+from robomimic.scripts.conversion.extract_action_dict import extract_action_dict
+
+import mimicgen
+from mimicgen.scripts.add_datagen_info import add_datagen_info
+
+DATASETS = [
+ "/tmp/coffee/src_10.hdf5",
+ "/tmp/stack/src_10.hdf5",
+]
+
+
+def convert_actions_in_dataset(dataset_path, output_name=None, absolute_mg=False):
+ """
+ Helper function to call the relevant scripts to get absolute action dicts for a given dataset.
+ """
+
+ # first get absolute actions
+ args = argparse.Namespace()
+ args.dataset = dataset_path
+ args.n = None
+ args.absolute = True
+ args.absolute_mg = absolute_mg
+
+ new_ds_path = dataset_path
+ if output_name is not None:
+ args.output = os.path.join(os.path.dirname(dataset_path), output_name)
+ new_ds_path = args.output
+ else:
+ args.output = None
+ add_datagen_info(args)
+
+ # next convert actions to dict
+ args = argparse.Namespace()
+ args.dataset = new_ds_path
+ extract_action_dict(args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--datasets",
+ type=str,
+ nargs='+',
+ default=None,
+ )
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--absolute_mg",
+ action='store_true',
+ help="extract absolute actions using existing datagen info, and skip extraction of datagen info",
+ )
+ parser.add_argument(
+ "--slack",
+ action='store_true',
+ help="try to give slack notification after script finishes",
+ )
+ args = parser.parse_args()
+
+ datasets = args.datasets
+ if datasets is None:
+ datasets = DATASETS
+
+ for d in datasets:
+ dataset_path = os.path.expanduser(d)
+ convert_actions_in_dataset(dataset_path, output_name=args.output_name, absolute_mg=args.absolute_mg)
+
+ if args.slack and (Macros.SLACK_TOKEN is not None):
+ from robomimic.scripts.give_slack_notification import give_slack_notif
+ msg = "Completed the following action conversion run!\nHostname: {}\n".format(socket.gethostname())
+ datasets_json = json.dumps(dict(datasets=datasets), indent=4)
+ msg += "```{}```".format(datasets_json)
+ give_slack_notif(msg)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/dataset_states_to_obs.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/dataset_states_to_obs.py
new file mode 100644
index 0000000000000000000000000000000000000000..6295d4f72dd6a8f11a4dd8a16f06fd7e88d2dc9c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/dataset_states_to_obs.py
@@ -0,0 +1,425 @@
+"""
+Script to extract observations from low-dimensional simulation states in a robosuite dataset.
+
+Args:
+ dataset (str): path to input hdf5 dataset
+
+ output_name (str): name of output hdf5 dataset
+
+ n (int): if provided, stop after n trajectories are processed
+
+ shaped (bool): if flag is set, use dense rewards
+
+ camera_names (str or [str]): camera name(s) to use for image observations.
+ Leave out to not use image observations.
+
+ camera_height (int): height of image observation.
+
+ camera_width (int): width of image observation
+
+ done_mode (int): how to write done signal. If 0, done is 1 whenever s' is a success state.
+ If 1, done is 1 at the end of each trajectory. If 2, both.
+
+ copy_rewards (bool): if provided, copy rewards from source file instead of inferring them
+
+ copy_dones (bool): if provided, copy dones from source file instead of inferring them
+
+Example usage:
+
+ # extract low-dimensional observations
+ python dataset_states_to_obs.py --dataset /path/to/demo.hdf5 --output_name low_dim.hdf5 --done_mode 2
+
+ # extract 84x84 image observations
+ python dataset_states_to_obs.py --dataset /path/to/demo.hdf5 --output_name image.hdf5 \
+ --done_mode 2 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+ # extract 84x84 image and depth observations
+ python dataset_states_to_obs.py --dataset /path/to/demo.hdf5 --output_name depth.hdf5 \
+ --done_mode 2 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84 --depth
+
+ # (space saving option) extract 84x84 image observations with compression and without
+ # extracting next obs (not needed for pure imitation learning algos)
+ python dataset_states_to_obs.py --dataset /path/to/demo.hdf5 --output_name image.hdf5 \
+ --done_mode 2 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84 \
+ --compress --exclude-next-obs
+
+ # use dense rewards, and only annotate the end of trajectories with done signal
+ python dataset_states_to_obs.py --dataset /path/to/demo.hdf5 --output_name image_dense_done_1.hdf5 \
+ --done_mode 1 --dense --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+"""
+import os
+import json
+import h5py
+import argparse
+import numpy as np
+from copy import deepcopy
+from tqdm import tqdm
+
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.env_utils as EnvUtils
+from robomimic.envs.env_base import EnvBase
+
+
+def extract_trajectory(
+ env,
+ initial_state,
+ states,
+ actions,
+ done_mode,
+ camera_names=None,
+ camera_height=84,
+ camera_width=84,
+):
+ """
+ Helper function to extract observations, rewards, and dones along a trajectory using
+ the simulator environment.
+
+ Args:
+ env (instance of EnvBase): environment
+ initial_state (dict): initial simulation state to load
+ states (np.array): array of simulation states to load to extract information
+ actions (np.array): array of actions
+ done_mode (int): how to write done signal. If 0, done is 1 whenever s' is a
+ success state. If 1, done is 1 at the end of each trajectory.
+ If 2, do both.
+ """
+ assert isinstance(env, EnvBase)
+ assert states.shape[0] == actions.shape[0]
+
+ # load the initial state
+ env.reset()
+ obs = env.reset_to(initial_state)
+
+ # maybe add in intrinsics and extrinsics for all cameras
+ camera_info = None
+ is_robosuite_env = EnvUtils.is_robosuite_env(env=env)
+ if is_robosuite_env:
+ camera_info = get_camera_info(
+ env=env,
+ camera_names=camera_names,
+ camera_height=camera_height,
+ camera_width=camera_width,
+ )
+
+ traj = dict(
+ obs=[],
+ next_obs=[],
+ rewards=[],
+ dones=[],
+ actions=np.array(actions),
+ states=np.array(states),
+ initial_state_dict=initial_state,
+ )
+ traj_len = states.shape[0]
+ # iteration variable @t is over "next obs" indices
+ for t in range(1, traj_len + 1):
+
+ # get next observation
+ if t == traj_len:
+ # play final action to get next observation for last timestep
+ next_obs, _, _, _ = env.step(actions[t - 1])
+ else:
+ # reset to simulator state to get observation
+ next_obs = env.reset_to({"states" : states[t]})
+
+ # infer reward signal
+ # note: our tasks use reward r(s'), reward AFTER transition, so this is
+ # the reward for the current timestep
+ r = env.get_reward()
+
+ # infer done signal
+ done = False
+ if (done_mode == 1) or (done_mode == 2):
+ # done = 1 at end of trajectory
+ done = done or (t == traj_len)
+ if (done_mode == 0) or (done_mode == 2):
+ # done = 1 when s' is task success state
+ done = done or env.is_success()["task"]
+ done = int(done)
+
+ # collect transition
+ traj["obs"].append(obs)
+ traj["next_obs"].append(next_obs)
+ traj["rewards"].append(r)
+ traj["dones"].append(done)
+
+ # update for next iter
+ obs = deepcopy(next_obs)
+
+ # convert list of dict to dict of list for obs dictionaries (for convenient writes to hdf5 dataset)
+ traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
+ traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])
+
+ # list to numpy array
+ for k in traj:
+ if k == "initial_state_dict":
+ continue
+ if isinstance(traj[k], dict):
+ for kp in traj[k]:
+ traj[k][kp] = np.array(traj[k][kp])
+ else:
+ traj[k] = np.array(traj[k])
+
+ return traj, camera_info
+
+
+def get_camera_info(
+ env,
+ camera_names=None,
+ camera_height=84,
+ camera_width=84,
+):
+ """
+ Helper function to get camera intrinsics and extrinsics for cameras being used for observations.
+ """
+
+ # TODO: make this function more general than just robosuite environments
+ assert EnvUtils.is_robosuite_env(env=env)
+
+ if camera_names is None:
+ return None
+
+ camera_info = dict()
+ for cam_name in camera_names:
+ K = env.get_camera_intrinsic_matrix(camera_name=cam_name, camera_height=camera_height, camera_width=camera_width)
+ R = env.get_camera_extrinsic_matrix(camera_name=cam_name) # camera pose in world frame
+ if "eye_in_hand" in cam_name:
+ # convert extrinsic matrix to be relative to robot eef control frame
+ assert cam_name.startswith("robot0")
+ eef_site_name = env.base_env.robots[0].controller.eef_name
+ eef_pos = np.array(env.base_env.sim.data.site_xpos[env.base_env.sim.model.site_name2id(eef_site_name)])
+ eef_rot = np.array(env.base_env.sim.data.site_xmat[env.base_env.sim.model.site_name2id(eef_site_name)].reshape([3, 3]))
+ eef_pose = np.zeros((4, 4)) # eef pose in world frame
+ eef_pose[:3, :3] = eef_rot
+ eef_pose[:3, 3] = eef_pos
+ eef_pose[3, 3] = 1.0
+ eef_pose_inv = np.zeros((4, 4))
+ eef_pose_inv[:3, :3] = eef_pose[:3, :3].T
+ eef_pose_inv[:3, 3] = -eef_pose_inv[:3, :3].dot(eef_pose[:3, 3])
+ eef_pose_inv[3, 3] = 1.0
+ R = R.dot(eef_pose_inv) # T_E^W * T_W^C = T_E^C
+ camera_info[cam_name] = dict(
+ intrinsics=K.tolist(),
+ extrinsics=R.tolist(),
+ )
+ return camera_info
+
+
+def dataset_states_to_obs(args):
+ if args.depth:
+ assert len(args.camera_names) > 0, "must specify camera names if using depth"
+
+ # create environment to use for data processing
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
+ env = EnvUtils.create_env_for_data_processing(
+ env_meta=env_meta,
+ camera_names=args.camera_names,
+ camera_height=args.camera_height,
+ camera_width=args.camera_width,
+ reward_shaping=args.shaped,
+ use_depth_obs=args.depth,
+ )
+
+ print("==== Using environment with the following metadata ====")
+ print(json.dumps(env.serialize(), indent=4))
+ print("")
+
+ # some operations for playback are robosuite-specific, so determine if this environment is a robosuite env
+ is_robosuite_env = EnvUtils.is_robosuite_env(env_meta)
+
+ # list of all demonstration episodes (sorted in increasing number order)
+ f = h5py.File(args.dataset, "r")
+ demos = list(f["data"].keys())
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+
+ # maybe reduce the number of demonstrations to playback
+ if args.n is not None:
+ demos = demos[:args.n]
+
+ # output file in same directory as input file
+ output_path = os.path.join(os.path.dirname(args.dataset), args.output_name)
+ f_out = h5py.File(output_path, "w")
+ data_grp = f_out.create_group("data")
+ print("input file: {}".format(args.dataset))
+ print("output file: {}".format(output_path))
+
+ total_samples = 0
+ for ind in tqdm(range(len(demos))):
+ ep = demos[ind]
+
+ # prepare initial state to reload from
+ states = f["data/{}/states".format(ep)][()]
+ initial_state = dict(states=states[0])
+ if is_robosuite_env:
+ initial_state["model"] = f["data/{}".format(ep)].attrs["model_file"]
+
+ # extract obs, rewards, dones
+ actions = f["data/{}/actions".format(ep)][()]
+ traj, camera_info = extract_trajectory(
+ env=env,
+ initial_state=initial_state,
+ states=states,
+ actions=actions,
+ done_mode=args.done_mode,
+ camera_names=args.camera_names,
+ camera_height=args.camera_height,
+ camera_width=args.camera_width,
+ )
+
+ # maybe copy reward or done signal from source file
+ if args.copy_rewards:
+ traj["rewards"] = f["data/{}/rewards".format(ep)][()]
+ if args.copy_dones:
+ traj["dones"] = f["data/{}/dones".format(ep)][()]
+
+ # store transitions
+
+ # IMPORTANT: keep name of group the same as source file, to make sure that filter keys are
+ # consistent as well
+ ep_data_grp = data_grp.create_group(ep)
+ ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
+ ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
+ ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
+ ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
+ for k in traj["obs"]:
+ if args.compress:
+ ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]), compression="gzip")
+ else:
+ ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
+ if not args.exclude_next_obs:
+ if args.compress:
+ ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]), compression="gzip")
+ else:
+ ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]))
+
+ # episode metadata
+ if is_robosuite_env:
+ ep_data_grp.attrs["model_file"] = traj["initial_state_dict"]["model"] # model xml for this episode
+ ep_data_grp.attrs["num_samples"] = traj["actions"].shape[0] # number of transitions in this episode
+
+ if camera_info is not None:
+ assert is_robosuite_env
+ ep_data_grp.attrs["camera_info"] = json.dumps(camera_info, indent=4)
+
+ total_samples += traj["actions"].shape[0]
+
+
+ # copy over all filter keys that exist in the original hdf5
+ if "mask" in f:
+ f.copy("mask", f_out)
+
+ # global metadata
+ data_grp.attrs["total"] = total_samples
+ data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4) # environment info
+ print("Wrote {} trajectories to {}".format(len(demos), output_path))
+
+ f.close()
+ f_out.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ required=True,
+ help="path to input hdf5 dataset",
+ )
+ # name of hdf5 to write - it will be in the same directory as @dataset
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ required=True,
+ help="name of output hdf5 dataset",
+ )
+
+ # specify number of demos to process - useful for debugging conversion with a handful
+ # of trajectories
+ parser.add_argument(
+ "--n",
+ type=int,
+ default=None,
+ help="(optional) stop after n trajectories are processed",
+ )
+
+ # flag for reward shaping
+ parser.add_argument(
+ "--shaped",
+ action='store_true',
+ help="(optional) use shaped rewards",
+ )
+
+ # camera names to use for observations
+ parser.add_argument(
+ "--camera_names",
+ type=str,
+ nargs='+',
+ default=[],
+ help="(optional) camera name(s) to use for image observations. Leave out to not use image observations.",
+ )
+
+ parser.add_argument(
+ "--camera_height",
+ type=int,
+ default=84,
+ help="(optional) height of image observations",
+ )
+
+ parser.add_argument(
+ "--camera_width",
+ type=int,
+ default=84,
+ help="(optional) width of image observations",
+ )
+
+ # flag for including depth observations per camera
+ parser.add_argument(
+ "--depth",
+ action='store_true',
+ help="(optional) use depth observations for each camera",
+ )
+
+ # specifies how the "done" signal is written. If "0", then the "done" signal is 1 wherever
+ # the transition (s, a, s') has s' in a task completion state. If "1", the "done" signal
+ # is one at the end of every trajectory. If "2", the "done" signal is 1 at task completion
+ # states for successful trajectories and 1 at the end of all trajectories.
+ parser.add_argument(
+ "--done_mode",
+ type=int,
+ default=0,
+ help="how to write done signal. If 0, done is 1 whenever s' is a success state.\
+ If 1, done is 1 at the end of each trajectory. If 2, both.",
+ )
+
+ # flag for copying rewards from source file instead of re-writing them
+ parser.add_argument(
+ "--copy_rewards",
+ action='store_true',
+ help="(optional) copy rewards from source file instead of inferring them",
+ )
+
+ # flag for copying dones from source file instead of re-writing them
+ parser.add_argument(
+ "--copy_dones",
+ action='store_true',
+ help="(optional) copy dones from source file instead of inferring them",
+ )
+
+ # flag to exclude next obs in dataset
+ parser.add_argument(
+ "--exclude-next-obs",
+ action='store_true',
+ help="(optional) exclude next obs in dataset",
+ )
+
+ # flag to compress observations with gzip option in hdf5
+ parser.add_argument(
+ "--compress",
+ action='store_true',
+ help="(optional) compress observations with gzip option in hdf5",
+ )
+
+ args = parser.parse_args()
+ dataset_states_to_obs(args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/download_datasets.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/download_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf3a280a14aec6f3c39157e9f9d84dd2a2486c4
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/download_datasets.py
@@ -0,0 +1,163 @@
+"""
+Script to download datasets packaged with the repository. By default, all
+datasets will be stored at robomimic/datasets, unless the @download_dir
+argument is supplied. We recommend using the default, as most examples that
+use these datasets assume that they can be found there.
+
+The @tasks, @dataset_types, and @hdf5_types arguments can all be supplied
+to choose which datasets to download.
+
+Args:
+ download_dir (str): Base download directory. Created if it doesn't exist.
+ Defaults to datasets folder in repository - only pass in if you would
+ like to override the location.
+
+ tasks (list): Tasks to download datasets for. Defaults to lift task. Pass 'all' to
+ download all tasks (sim + real) 'sim' to download all sim tasks, 'real' to
+ download all real tasks, or directly specify the list of tasks.
+
+ dataset_types (list): Dataset types to download datasets for (e.g. ph, mh, mg).
+ Defaults to ph. Pass 'all' to download datasets for all available dataset
+ types per task, or directly specify the list of dataset types.
+
+ hdf5_types (list): hdf5 types to download datasets for (e.g. raw, low_dim, image).
+ Defaults to low_dim. Pass 'all' to download datasets for all available hdf5
+ types per task and dataset, or directly specify the list of hdf5 types.
+
+Example usage:
+
+ # default behavior - just download lift proficient-human low-dim dataset
+ python download_datasets.py
+
+ # download low-dim proficient-human datasets for all simulation tasks
+ # (do a dry run first to see which datasets would be downloaded)
+ python download_datasets.py --tasks sim --dataset_types ph --hdf5_types low_dim --dry_run
+ python download_datasets.py --tasks sim --dataset_types ph --hdf5_types low_dim
+
+ # download all low-dim and image multi-human datasets for the can and square tasks
+ python download_datasets.py --tasks can square --dataset_types mh --hdf5_types low_dim image
+
+ # download the sparse reward machine-generated low-dim datasets
+ python download_datasets.py --tasks all --dataset_types mg --hdf5_types low_dim_sparse
+
+ # download all real robot datasets
+ python download_datasets.py --tasks real
+"""
+import os
+import argparse
+
+import robomimic
+import robomimic.utils.file_utils as FileUtils
+from robomimic import DATASET_REGISTRY
+
+ALL_TASKS = ["lift", "can", "square", "transport", "tool_hang", "lift_real", "can_real", "tool_hang_real"]
+ALL_DATASET_TYPES = ["ph", "mh", "mg", "paired"]
+ALL_HDF5_TYPES = ["raw", "low_dim", "image", "low_dim_sparse", "low_dim_dense", "image_sparse", "image_dense"]
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # directory to download datasets to
+ parser.add_argument(
+ "--download_dir",
+ type=str,
+ default=None,
+ help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.",
+ )
+
+ # tasks to download datasets for
+ parser.add_argument(
+ "--tasks",
+ type=str,
+ nargs='+',
+ default=["lift"],
+ help="Tasks to download datasets for. Defaults to lift task. Pass 'all' to download all tasks (sim + real)\
+ 'sim' to download all sim tasks, 'real' to download all real tasks, or directly specify the list of\
+ tasks.",
+ )
+
+ # dataset types to download datasets for
+ parser.add_argument(
+ "--dataset_types",
+ type=str,
+ nargs='+',
+ default=["ph"],
+ help="Dataset types to download datasets for (e.g. ph, mh, mg). Defaults to ph. Pass 'all' to download \
+ datasets for all available dataset types per task, or directly specify the list of dataset types.",
+ )
+
+ # hdf5 types to download datasets for
+ parser.add_argument(
+ "--hdf5_types",
+ type=str,
+ nargs='+',
+ default=["low_dim"],
+ help="hdf5 types to download datasets for (e.g. raw, low_dim, image). Defaults to raw. Pass 'all' \
+ to download datasets for all available hdf5 types per task and dataset, or directly specify the list\
+ of hdf5 types.",
+ )
+
+ # dry run - don't actually download datasets, but print which datasets would be downloaded
+ parser.add_argument(
+ "--dry_run",
+ action='store_true',
+ help="set this flag to do a dry run to only print which datasets would be downloaded"
+ )
+
+ args = parser.parse_args()
+
+ # set default base directory for downloads
+ default_base_dir = args.download_dir
+ if default_base_dir is None:
+ default_base_dir = os.path.join(robomimic.__path__[0], "../datasets")
+
+ # load args
+ download_tasks = args.tasks
+ if "all" in download_tasks:
+ assert len(download_tasks) == 1, "all should be only tasks argument but got: {}".format(args.tasks)
+ download_tasks = ALL_TASKS
+ elif "sim" in download_tasks:
+ assert len(download_tasks) == 1, "sim should be only tasks argument but got: {}".format(args.tasks)
+ download_tasks = [task for task in ALL_TASKS if "real" not in task]
+ elif "real" in download_tasks:
+ assert len(download_tasks) == 1, "real should be only tasks argument but got: {}".format(args.tasks)
+ download_tasks = [task for task in ALL_TASKS if "real" in task]
+
+ download_dataset_types = args.dataset_types
+ if "all" in download_dataset_types:
+ assert len(download_dataset_types) == 1, "all should be only dataset_types argument but got: {}".format(args.dataset_types)
+ download_dataset_types = ALL_DATASET_TYPES
+
+ download_hdf5_types = args.hdf5_types
+ if "all" in download_hdf5_types:
+ assert len(download_hdf5_types) == 1, "all should be only hdf5_types argument but got: {}".format(args.hdf5_types)
+ download_hdf5_types = ALL_HDF5_TYPES
+
+ # download requested datasets
+ for task in DATASET_REGISTRY:
+ if task in download_tasks:
+ for dataset_type in DATASET_REGISTRY[task]:
+ if dataset_type in download_dataset_types:
+ for hdf5_type in DATASET_REGISTRY[task][dataset_type]:
+ if hdf5_type in download_hdf5_types:
+ download_dir = os.path.abspath(os.path.join(default_base_dir, task, dataset_type))
+ print("\nDownloading dataset:\n task: {}\n dataset type: {}\n hdf5 type: {}\n download path: {}"
+ .format(task, dataset_type, hdf5_type, download_dir))
+ url = DATASET_REGISTRY[task][dataset_type][hdf5_type]["url"]
+ if url is None:
+ print(
+ "Skipping {}-{}-{}, no url for dataset exists.".format(task, dataset_type, hdf5_type)
+ + " Create this dataset locally by running the appropriate command from robomimic/scripts/extract_obs_from_raw_datasets.sh."
+ )
+ continue
+ if args.dry_run:
+ print("\ndry run: skip download")
+ else:
+ # Make sure path exists and create if it doesn't
+ os.makedirs(download_dir, exist_ok=True)
+ FileUtils.download_url(
+ url=DATASET_REGISTRY[task][dataset_type][hdf5_type]["url"],
+ download_dir=download_dir,
+ )
+ print("")
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/download_momart_datasets.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/download_momart_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..affecf11b525f39aaae47095bb85c6086a955a70
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/download_momart_datasets.py
@@ -0,0 +1,161 @@
+"""
+Script to download datasets used in MoMaRT paper (https://arxiv.org/abs/2112.05251). By default, all
+datasets will be stored at robomimic/datasets, unless the @download_dir
+argument is supplied. We recommend using the default, as most examples that
+use these datasets assume that they can be found there.
+
+The @tasks and @dataset_types arguments can all be supplied
+to choose which datasets to download.
+
+Args:
+ download_dir (str): Base download directory. Created if it doesn't exist.
+ Defaults to datasets folder in repository - only pass in if you would
+ like to override the location.
+
+ tasks (list): Tasks to download datasets for. Defaults to table_setup_from_dishwasher task. Pass 'all' to
+ download all tasks - 5 total:
+ - table_setup_from_dishwasher
+ - table_setup_from_dresser
+ - table_cleanup_to_dishwasher
+ - table_cleanup_to_sink
+ - unload_dishwasher
+
+ dataset_types (list): Dataset types to download datasets for (expert, suboptimal, generalize, sample).
+ Defaults to expert. Pass 'all' to download datasets for all available dataset
+ types per task, or directly specify the list of dataset types.
+ NOTE: Because these datasets are huge, we will always print out a warning
+ that a user must respond yes to to acknowledge the data size (can be up to >100G for all tasks of a single type)
+
+Example usage:
+
+ # default behavior - just download expert table_setup_from_dishwasher dataset
+ python download_momart_datasets.py
+
+ # download expert datasets for all tasks
+ # (do a dry run first to see which datasets would be downloaded)
+ python download_momart_datasets.py --tasks all --dataset_types expert --dry_run
+ python download_momart_datasets.py --tasks all --dataset_types expert low_dim
+
+ # download all expert and suboptimal datasets for the table_setup_from_dishwasher and table_cleanup_to_dishwasher tasks
+ python download_datasets.py --tasks table_setup_from_dishwasher table_cleanup_to_dishwasher --dataset_types expert suboptimal
+
+ # download the sample datasets
+ python download_datasets.py --tasks all --dataset_types sample
+
+ # download all datasets
+ python download_datasets.py --tasks all --dataset_types all
+"""
+import os
+import argparse
+
+import robomimic
+import robomimic.utils.file_utils as FileUtils
+from robomimic import MOMART_DATASET_REGISTRY
+
+ALL_TASKS = [
+ "table_setup_from_dishwasher",
+ "table_setup_from_dresser",
+ "table_cleanup_to_dishwasher",
+ "table_cleanup_to_sink",
+ "unload_dishwasher",
+]
+ALL_DATASET_TYPES = [
+ "expert",
+ "suboptimal",
+ "generalize",
+ "sample",
+]
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # directory to download datasets to
+ parser.add_argument(
+ "--download_dir",
+ type=str,
+ default=None,
+ help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.",
+ )
+
+ # tasks to download datasets for
+ parser.add_argument(
+ "--tasks",
+ type=str,
+ nargs='+',
+ default=["table_setup_from_dishwasher"],
+ help="Tasks to download datasets for. Defaults to table_setup_from_dishwasher task. Pass 'all' to download all"
+ f"5 tasks, or directly specify the list of tasks. Options are any of: {ALL_TASKS}",
+ )
+
+ # dataset types to download datasets for
+ parser.add_argument(
+ "--dataset_types",
+ type=str,
+ nargs='+',
+ default=["expert"],
+ help="Dataset types to download datasets for (e.g. expert, suboptimal). Defaults to expert. Pass 'all' to "
+ "download datasets for all available dataset types per task, or directly specify the list of dataset "
+ f"types. Options are any of: {ALL_DATASET_TYPES}",
+ )
+
+ # dry run - don't actually download datasets, but print which datasets would be downloaded
+ parser.add_argument(
+ "--dry_run",
+ action='store_true',
+ help="set this flag to do a dry run to only print which datasets would be downloaded"
+ )
+
+ args = parser.parse_args()
+
+ # set default base directory for downloads
+ default_base_dir = args.download_dir
+ if default_base_dir is None:
+ default_base_dir = os.path.join(robomimic.__path__[0], "../datasets")
+
+ # load args
+ download_tasks = args.tasks
+ if "all" in download_tasks:
+ assert len(download_tasks) == 1, "all should be only tasks argument but got: {}".format(args.tasks)
+ download_tasks = ALL_TASKS
+
+ download_dataset_types = args.dataset_types
+ if "all" in download_dataset_types:
+ assert len(download_dataset_types) == 1, "all should be only dataset_types argument but got: {}".format(args.dataset_types)
+ download_dataset_types = ALL_DATASET_TYPES
+
+ # Run sanity check first to warn user if they're about to download a huge amount of data
+ total_size = 0
+ for task in MOMART_DATASET_REGISTRY:
+ if task in download_tasks:
+ for dataset_type in MOMART_DATASET_REGISTRY[task]:
+ if dataset_type in download_dataset_types:
+ total_size += MOMART_DATASET_REGISTRY[task][dataset_type]["size"]
+
+ # Verify user acknowledgement if we're not doing a dry run
+ if not args.dry_run:
+ user_response = input(f"Warning: requested datasets will take a total of {total_size}GB. Proceed? y/n\n")
+ assert user_response.lower() in {"yes", "y"}, f"Did not receive confirmation. Aborting download."
+
+ # download requested datasets
+ for task in MOMART_DATASET_REGISTRY:
+ if task in download_tasks:
+ for dataset_type in MOMART_DATASET_REGISTRY[task]:
+ if dataset_type in download_dataset_types:
+ dataset_info = MOMART_DATASET_REGISTRY[task][dataset_type]
+ download_dir = os.path.abspath(os.path.join(default_base_dir, task, dataset_type))
+ print(f"\nDownloading dataset:\n"
+ f" task: {task}\n"
+ f" dataset type: {dataset_type}\n"
+ f" dataset size: {dataset_info['size']}GB\n"
+ f" download path: {download_dir}")
+ if args.dry_run:
+ print("\ndry run: skip download")
+ else:
+ # Make sure path exists and create if it doesn't
+ os.makedirs(download_dir, exist_ok=True)
+ FileUtils.download_url(
+ url=dataset_info["url"],
+ download_dir=download_dir,
+ )
+ print("")
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/extract_obs_from_raw_datasets.sh b/phantom/submodules/phantom-robomimic/robomimic/scripts/extract_obs_from_raw_datasets.sh
new file mode 100644
index 0000000000000000000000000000000000000000..00fc78f8bf08df5339e79c65019db683dfac6e59
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/extract_obs_from_raw_datasets.sh
@@ -0,0 +1,140 @@
+#!/bin/bash
+
+# This script holds the commands that were used to go from raw robosuite demo.hdf5 files
+# to our processed low-dim and image hdf5 files.
+
+BASE_DATASET_DIR="../../datasets"
+echo "Using base dataset directory: $BASE_DATASET_DIR"
+
+
+### NOTE: we use done-mode 0 for MG (dones on task success) ###
+
+
+### mg ###
+
+
+# lift - mg, sparse
+python dataset_states_to_obs.py --done_mode 0 \
+--dataset $BASE_DATASET_DIR/lift/mg/demo_v141.hdf5 \
+--output_name low_dim_sparse_v141.hdf5
+python dataset_states_to_obs.py --done_mode 0 \
+--dataset $BASE_DATASET_DIR/lift/mg/demo_v141.hdf5 \
+--output_name image_sparse_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# lift - mg, dense
+python dataset_states_to_obs.py --done_mode 0 --shaped \
+--dataset $BASE_DATASET_DIR/lift/mg/demo_v141.hdf5 \
+--output_name low_dim_dense_v141.hdf5
+python dataset_states_to_obs.py --done_mode 0 --shaped \
+--dataset $BASE_DATASET_DIR/lift/mg/demo_v141.hdf5 \
+--output_name image_dense_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# can - mg, sparse
+python dataset_states_to_obs.py --done_mode 0 \
+--dataset $BASE_DATASET_DIR/can/mg/demo_v141.hdf5 \
+--output_name low_dim_sparse_v141.hdf5
+python dataset_states_to_obs.py --done_mode 0 \
+--dataset $BASE_DATASET_DIR/can/mg/demo_v141.hdf5 \
+--output_name image_sparse_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# can - mg, dense
+python dataset_states_to_obs.py --done_mode 0 --shaped \
+--dataset $BASE_DATASET_DIR/can/mg/demo_v141.hdf5 \
+--output_name low_dim_dense_v141.hdf5
+python dataset_states_to_obs.py --done_mode 0 --shaped \
+--dataset $BASE_DATASET_DIR/can/mg/demo_v141.hdf5 \
+--output_name image_dense_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+
+### NOTE: we use done-mode 2 for PH / MH (dones on task success and end of trajectory) ###
+
+
+### ph ###
+
+
+# lift - ph
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/lift/ph/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/lift/ph/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# can - ph
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/can/ph/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/can/ph/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# square - ph
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/square/ph/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/square/ph/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# transport - ph
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/transport/ph/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/transport/ph/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names shouldercamera0 shouldercamera1 robot0_eye_in_hand robot1_eye_in_hand --camera_height 84 --camera_width 84
+
+# tool hang - ph
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/tool_hang/ph/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/tool_hang/ph/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names sideview robot0_eye_in_hand --camera_height 240 --camera_width 240
+
+
+### mh ###
+
+
+# lift - mh
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/lift/mh/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/lift/mh/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# can - mh
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/can/mh/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/can/mh/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# square - mh
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/square/mh/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/square/mh/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
+
+# transport - mh
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/transport/mh/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/transport/mh/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names shouldercamera0 shouldercamera1 robot0_eye_in_hand robot1_eye_in_hand --camera_height 84 --camera_width 84
+
+
+### can-paired ###
+
+
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/can/paired/demo_v141.hdf5 \
+--output_name low_dim_v141.hdf5
+python dataset_states_to_obs.py --done_mode 2 \
+--dataset $BASE_DATASET_DIR/can/paired/demo_v141.hdf5 \
+--output_name image_v141.hdf5 --camera_names agentview robot0_eye_in_hand --camera_height 84 --camera_width 84
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/generate_config_templates.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/generate_config_templates.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e1d8710c124cd418850bf25e016873ed88c49d
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/generate_config_templates.py
@@ -0,0 +1,28 @@
+"""
+Helpful script to generate example config files for each algorithm. These should be re-generated
+when new config options are added, or when default settings in the config classes are modified.
+"""
+import os
+import json
+
+import robomimic
+from robomimic.config import get_all_registered_configs
+
+
+def main():
+ # store template config jsons in this directory
+ target_dir = os.path.join(robomimic.__path__[0], "exps/templates/")
+
+ # iterate through registered algorithm config classes
+ all_configs = get_all_registered_configs()
+ for algo_name in all_configs:
+ # make config class for this algorithm
+ c = all_configs[algo_name]()
+ assert algo_name == c.algo_name
+ # dump to json
+ json_path = os.path.join(target_dir, "{}.json".format(algo_name))
+ c.dump(filename=json_path)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/generate_paper_configs.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/generate_paper_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ed7d5b15a25def7da7a02c7c0e135772f269a0
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/generate_paper_configs.py
@@ -0,0 +1,1369 @@
+"""
+Helper script to generate jsons for reproducing paper experiments.
+
+Args:
+ config_dir (str): Directory where generated configs will be placed.
+ Defaults to 'paper' subfolder in exps folder of repository
+
+ dataset_dir (str): Base dataset directory where released datasets can be
+ found on disk. Defaults to datasets folder in repository.
+
+ output_dir (str): Base output directory for all training runs that will be
+ written to generated configs.
+
+Example usage:
+ # Assume datasets alredy exist in robomimic/../datasets folder. Configs will be generated under robomimic/exps/paper
+ python generate_paper_configs.py --output_dir /tmp/experiment_results
+
+ # Specify where datasets exist, and specify where configs should be generated.
+ python generate_paper_configs.py --config_dir /tmp/configs --dataset_dir /tmp/datasets --output_dir /tmp/experiment_results
+"""
+import os
+import argparse
+import robomimic
+from robomimic import DATASET_REGISTRY
+from robomimic.config import Config, BCConfig, BCQConfig, CQLConfig, HBCConfig, IRISConfig, config_factory
+
+
+def modify_config_for_default_low_dim_exp(config):
+ """
+ Modifies a Config object with experiment, training, and observation settings that
+ were used across all low-dimensional experiments by default.
+
+ Args:
+ config (Config instance): config to modify
+ """
+
+ with config.experiment.values_unlocked():
+ # save model during every evaluation (every 50 epochs)
+ config.experiment.save.enabled = True
+ config.experiment.save.every_n_epochs = 50
+
+ # every epoch is 100 gradient steps, and validation epoch is 10 gradient steps
+ config.experiment.epoch_every_n_steps = 100
+ config.experiment.validation_epoch_every_n_steps = 10
+
+ # do 50 evaluation rollouts every 50 epochs
+ # NOTE: horizon will generally get set depending on the task and dataset type
+ config.experiment.rollout.enabled = True
+ config.experiment.rollout.n = 50
+ config.experiment.rollout.horizon = 400
+ config.experiment.rollout.rate = 50
+ config.experiment.rollout.warmstart = 0
+ config.experiment.rollout.terminate_on_success = True
+
+ with config.train.values_unlocked():
+ # assume entire dataset can fit in memory
+ config.train.num_data_workers = 0
+ config.train.hdf5_cache_mode = "all"
+
+ # batch size 100 and 2000 training epochs
+ config.train.batch_size = 100
+ config.train.num_epochs = 2000
+
+ with config.observation.values_unlocked():
+ # default observation is eef pose, gripper finger position, and object information,
+ # all of which are low-dim.
+ default_low_dim_obs = [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "object",
+ ]
+ # handle hierarchical observation configs
+ if config.algo_name == "hbc":
+ configs_to_set = [
+ config.observation.actor.modalities.obs,
+ config.observation.planner.modalities.obs,
+ config.observation.planner.modalities.subgoal,
+ ]
+ elif config.algo_name == "iris":
+ configs_to_set = [
+ config.observation.actor.modalities.obs,
+ config.observation.value_planner.planner.modalities.obs,
+ config.observation.value_planner.planner.modalities.subgoal,
+ config.observation.value_planner.value.modalities.obs,
+ ]
+ else:
+ configs_to_set = [config.observation.modalities.obs]
+ # set all observations / subgoals to use the correct low-dim modalities
+ for cfg in configs_to_set:
+ cfg.low_dim = list(default_low_dim_obs)
+ cfg.rgb = []
+
+ return config
+
+
+def modify_config_for_default_image_exp(config):
+ """
+ Modifies a Config object with experiment, training, and observation settings that
+ were used across all image experiments by default.
+
+ Args:
+ config (Config instance): config to modify
+ """
+ assert config.algo_name not in ["hbc", "iris"], "no image training for HBC and IRIS"
+
+ with config.experiment.values_unlocked():
+ # save model during every evaluation (every 20 epochs)
+ config.experiment.save.enabled = True
+ config.experiment.save.every_n_epochs = 20
+
+ # every epoch is 500 gradient steps, and validation epoch is 50 gradient steps
+ config.experiment.epoch_every_n_steps = 500
+ config.experiment.validation_epoch_every_n_steps = 50
+
+ # do 50 evaluation rollouts every 20 epochs
+ # NOTE: horizon will generally get set depending on the task and dataset type
+ config.experiment.rollout.enabled = True
+ config.experiment.rollout.n = 50
+ config.experiment.rollout.horizon = 400
+ config.experiment.rollout.rate = 20
+ config.experiment.rollout.warmstart = 0
+ config.experiment.rollout.terminate_on_success = True
+
+ with config.train.values_unlocked():
+ # only cache low-dim info, and use 2 data workers to increase fetch speed for image obs
+ config.train.num_data_workers = 2
+ config.train.hdf5_cache_mode = "low_dim"
+
+ # batch size 16 and 600 training epochs
+ config.train.batch_size = 16
+ config.train.num_epochs = 600
+
+
+ with config.observation.values_unlocked():
+ # default low-dim observation is eef pose, gripper finger position
+ # default image observation is external camera and wrist camera
+ config.observation.modalities.obs.low_dim = [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ ]
+ config.observation.modalities.obs.rgb = [
+ "agentview_image",
+ "robot0_eye_in_hand_image",
+ ]
+ config.observation.modalities.goal.low_dim = []
+ config.observation.modalities.goal.rgb = []
+
+ # default image encoder architecture is ResNet with spatial softmax
+ config.observation.encoder.rgb.core_class = "VisualCore"
+ config.observation.encoder.rgb.core_kwargs.feature_dimension = 64
+ config.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
+ config.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
+
+ # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
+ config.observation.encoder.rgb.obs_randomizer_class = "CropRandomizer"
+
+ # kwargs for observation randomizers (for the CropRandomizer, this is size and number of crops)
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.num_crops = 1
+ config.observation.encoder.rgb.obs_randomizer_kwargs.pos_enc = False
+
+ return config
+
+
+def modify_config_for_dataset(config, task_name, dataset_type, hdf5_type, base_dataset_dir, filter_key=None):
+ """
+ Modifies a Config object with experiment, training, and observation settings to
+ correspond to experiment settings for the dataset collected on @task_name with
+ dataset source @dataset_type (e.g. ph, mh, mg), and hdf5 type @hdf5_type (e.g. low_dim
+ or image).
+
+ Args:
+ config (Config instance): config to modify
+
+ task_name (str): identify task that dataset was collected on
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ filter_key (str): if not None, use the provided filter key to select a subset of the
+ provided dataset
+ """
+ assert task_name in DATASET_REGISTRY, \
+ "task {} not found in dataset registry!".format(task_name)
+ assert dataset_type in DATASET_REGISTRY[task_name], \
+ "dataset type {} not found for task {} in dataset registry!".format(dataset_type, task_name)
+ assert hdf5_type in DATASET_REGISTRY[task_name][dataset_type], \
+ "hdf5 type {} not found for dataset type {} and task {} in dataset registry!".format(hdf5_type, dataset_type, task_name)
+
+ is_real_dataset = "real" in task_name
+ if is_real_dataset:
+ assert config.algo_name == "bc", "we only ran BC-RNN on real robot"
+ else:
+ assert hdf5_type != "raw", "cannot train on raw demonstrations"
+
+ with config.experiment.values_unlocked():
+
+ # look up rollout evaluation horizon in registry and set it
+ config.experiment.rollout.horizon = DATASET_REGISTRY[task_name][dataset_type][hdf5_type]["horizon"]
+
+ if dataset_type == "mg":
+ # machine-generated datasets did not use validation
+ config.experiment.validate = False
+ else:
+ # all other datasets used validation
+ config.experiment.validate = True
+
+ if is_real_dataset:
+ # no evaluation rollouts for real robot training
+ config.experiment.rollout.enabled = False
+
+ with config.train.values_unlocked():
+ # set dataset path and possibly filter keys
+ url = DATASET_REGISTRY[task_name][dataset_type][hdf5_type]["url"]
+ if url is None:
+ # infer file_name
+ if task_name in ["lift", "can", "square", "tool_hang", "transport"]:
+ file_name = "{}_v141.hdf5".format(hdf5_type)
+ elif task_name in ["lift_real", "can_real", "tool_hang_real"]:
+ file_name = "{}.hdf5".format(hdf5_type)
+ else:
+ raise ValueError("Unknown dataset type")
+ else:
+ file_name = url.split("/")[-1]
+ config.train.data = os.path.join(base_dataset_dir, task_name, dataset_type, file_name)
+ config.train.hdf5_filter_key = None if filter_key is None else filter_key
+ config.train.hdf5_validation_filter_key = None
+ if config.experiment.validate:
+ # set train and valid keys for validation
+ config.train.hdf5_filter_key = "train" if filter_key is None else "{}_train".format(filter_key)
+ config.train.hdf5_validation_filter_key = "valid" if filter_key is None else "{}_valid".format(filter_key)
+
+ with config.observation.values_unlocked():
+ # maybe modify observation names and randomization sizes (since image size might be different)
+
+ if is_real_dataset:
+ # modify observation names for real robot datasets
+ config.observation.modalities.obs.low_dim = [
+ "ee_pose",
+ "gripper_position",
+ ]
+
+ if task_name == "tool_hang_real":
+ # side and wrist camera
+ config.observation.modalities.obs.rgb = [
+ "image_side",
+ "image_wrist",
+ ]
+ # 240x240 images -> crops should be 216x216
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 216
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 216
+ else:
+ # front and wrist camera
+ config.observation.modalities.obs.rgb = [
+ "image",
+ "image_wrist",
+ ]
+ # 120x120 images -> crops should be 108x108
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 108
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 108
+
+ elif hdf5_type in ["image", "image_sparse", "image_dense"]:
+ if task_name == "transport":
+ # robot proprioception per arm
+ config.observation.modalities.obs.low_dim = [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "robot1_eef_pos",
+ "robot1_eef_quat",
+ "robot1_gripper_qpos",
+ ]
+
+ # shoulder and wrist cameras per arm
+ config.observation.modalities.obs.rgb = [
+ "shouldercamera0_image",
+ "robot0_eye_in_hand_image",
+ "shouldercamera1_image",
+ "robot1_eye_in_hand_image",
+ ]
+ elif task_name == "tool_hang":
+ # side and wrist camera
+ config.observation.modalities.obs.rgb = [
+ "sideview_image",
+ "robot0_eye_in_hand_image",
+ ]
+ # 240x240 images -> crops should be 216x216
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 216
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 216
+
+ elif hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ if task_name == "transport":
+ # robot proprioception per arm
+ default_low_dim_obs = [
+ "robot0_eef_pos",
+ "robot0_eef_quat",
+ "robot0_gripper_qpos",
+ "robot1_eef_pos",
+ "robot1_eef_quat",
+ "robot1_gripper_qpos",
+ "object",
+ ]
+ # handle hierarchical observation configs
+ if config.algo_name == "hbc":
+ configs_to_set = [
+ config.observation.actor.modalities.obs,
+ config.observation.planner.modalities.obs,
+ config.observation.planner.modalities.subgoal,
+ ]
+ elif config.algo_name == "iris":
+ configs_to_set = [
+ config.observation.actor.modalities.obs,
+ config.observation.value_planner.planner.modalities.obs,
+ config.observation.value_planner.planner.modalities.subgoal,
+ config.observation.value_planner.value.modalities.obs,
+ ]
+ else:
+ configs_to_set = [config.observation.modalities.obs]
+ # set all observations / subgoals to use the correct low-dim modalities
+ for obs_key_config in configs_to_set:
+ obs_key_config.low_dim = list(default_low_dim_obs)
+ obs_key_config.rgb = []
+
+ return config
+
+
+def modify_bc_config_for_dataset(config, task_name, dataset_type, hdf5_type):
+ """
+ Modifies a BCConfig object for training on a particular kind of dataset. This function
+ just sets algorithm hyperparameters in the algo config depending on the kind of
+ dataset.
+
+ Args:
+ config (BCConfig instance): config to modify
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+ """
+ assert isinstance(config, BCConfig), "must be BCConfig"
+ assert config.algo_name == "bc", "must be BCConfig"
+ assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
+ is_real_dataset = "real" in task_name
+ if not is_real_dataset:
+ assert hdf5_type != "raw", "cannot train on raw demonstrations"
+
+ with config.algo.values_unlocked():
+ # base parameters that may get modified
+ config.algo.optim_params.policy.learning_rate.initial = 1e-4 # learning rate 1e-4
+ config.algo.actor_layer_dims = (1024, 1024) # MLP size (1024, 1024)
+ config.algo.gmm.enabled = True # enable GMM
+
+ if dataset_type == "mg":
+ # machine-generated datasets don't use GMM
+ config.algo.gmm.enabled = False # disable GMM
+ if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ # low-dim mg uses LR 1e-3
+ config.algo.optim_params.policy.learning_rate.initial = 1e-3 # learning rate 1e-3
+
+ return config
+
+
+def modify_bc_rnn_config_for_dataset(config, task_name, dataset_type, hdf5_type):
+ """
+ Modifies a BCConfig object for training on a particular kind of dataset. This function
+ just sets algorithm hyperparameters in the algo config depending on the kind of
+ dataset.
+
+ Args:
+ config (BCConfig instance): config to modify
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+ """
+ assert isinstance(config, BCConfig), "must be BCConfig"
+ assert config.algo_name == "bc", "must be BCConfig"
+ assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
+ is_real_dataset = "real" in task_name
+ if not is_real_dataset:
+ assert hdf5_type != "raw", "cannot train on raw demonstrations"
+
+ with config.train.values_unlocked():
+ # make sure RNN is enabled with sequence length 10
+ config.train.seq_length = 10
+
+ with config.algo.values_unlocked():
+ # make sure RNN is enabled with sequence length 10
+ config.algo.rnn.enabled = True
+ config.algo.rnn.horizon = 10
+
+ # base parameters that may get modified
+ config.algo.optim_params.policy.learning_rate.initial = 1e-4 # learning rate 1e-4
+ config.algo.actor_layer_dims = () # no MLP layers between rnn layer and output
+ config.algo.gmm.enabled = True # enable GMM
+ config.algo.rnn.hidden_dim = 400 # rnn dim 400
+
+ if dataset_type == "mg":
+ # update hyperparams for machine-generated datasets
+ config.algo.gmm.enabled = False # disable GMM
+ if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ # image datasets use RNN dim 1000
+ config.algo.rnn.hidden_dim = 1000 # rnn dim 1000
+ else:
+ # update hyperparams for all other dataset types (ph, mh, paired)
+ if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ # image datasets use RNN dim 1000
+ config.algo.rnn.hidden_dim = 1000 # rnn dim 1000
+
+ return config
+
+
+def modify_bcq_config_for_dataset(config, task_name, dataset_type, hdf5_type):
+ """
+ Modifies a BCQConfig object for training on a particular kind of dataset. This function
+ just sets algorithm hyperparameters in the algo config depending on the kind of
+ dataset.
+
+ Args:
+ config (BCQConfig instance): config to modify
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+ """
+ assert isinstance(config, BCQConfig), "must be BCQConfig"
+ assert config.algo_name == "bcq", "must be BCQConfig"
+ assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
+ is_real_dataset = "real" in task_name
+ assert not is_real_dataset, "we only ran BC-RNN on real robot"
+ if not is_real_dataset:
+ assert hdf5_type != "raw", "cannot train on raw demonstrations"
+
+ with config.algo.values_unlocked():
+ # base parameters that may get modified further
+ config.algo.optim_params.critic.learning_rate.initial = 1e-4 # all learning rates 1e-3
+ config.algo.optim_params.action_sampler.learning_rate.initial = 1e-4
+ config.algo.optim_params.actor.learning_rate.initial = 1e-3
+ config.algo.actor.enabled = False # disable actor by default
+ config.algo.action_sampler.vae.enabled = True # use VAE action sampler
+ config.algo.action_sampler.gmm.enabled = False
+ config.algo.action_sampler.vae.kl_weight = 0.05 # beta 0.05 for VAE
+ config.algo.action_sampler.vae.latent_dim = 14 # latent dim 14
+ config.algo.action_sampler.vae.prior.learn = False # N(0, 1) prior
+ config.algo.critic.layer_dims = (300, 400) # all MLP sizes at (300, 400)
+ config.algo.action_sampler.vae.encoder_layer_dims = (300, 400)
+ config.algo.action_sampler.vae.decoder_layer_dims = (300, 400)
+ config.algo.actor.layer_dims = (300, 400)
+ config.algo.target_tau = 5e-4 # tau 5e-4
+ config.algo.discount = 0.99 # discount 0.99
+ config.algo.critic.num_action_samples = 10 # number of action sampler samples at train and test
+ config.algo.critic.num_action_samples_rollout = 100
+
+ if dataset_type == "mg":
+ # update hyperparams for machine-generated datasets
+ config.algo.optim_params.critic.learning_rate.initial = 1e-3 # all learning rates 1e-3
+ config.algo.optim_params.action_sampler.learning_rate.initial = 1e-3
+ config.algo.optim_params.actor.learning_rate.initial = 1e-3
+ config.algo.action_sampler.vae.kl_weight = 0.5 # beta 0.5 for VAE
+ config.algo.target_tau = 5e-3 # tau 5e-3
+
+ if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ # enable actor only on low-dim
+ config.algo.actor.enabled = True
+ else:
+ # make some modifications where needed for human datasets
+ if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ if dataset_type in ["mh", "paired"]:
+ # low-dim, MH had higher layer sizes
+ config.algo.critic.layer_dims = (1024, 1024)
+ config.algo.action_sampler.vae.encoder_layer_dims = (1024, 1024)
+ config.algo.action_sampler.vae.decoder_layer_dims = (1024, 1024)
+ config.algo.action_sampler.vae.prior_layer_dims = (1024, 1024)
+
+ config.algo.action_sampler.vae.kl_weight = 0.5
+
+ # use learned GMM prior for MH dataset
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = True
+ config.algo.action_sampler.vae.prior.use_gmm = True
+ config.algo.action_sampler.vae.prior.gmm_learn_weights = True
+ else:
+ if dataset_type == "ph":
+ # image, PH used higher critic LR of 1e-3
+ config.algo.optim_params.critic.learning_rate.initial = 1e-3
+ # image datasets used bigger VAE
+ config.algo.action_sampler.vae.encoder_layer_dims = (1024, 1024)
+ config.algo.action_sampler.vae.decoder_layer_dims = (1024, 1024)
+ if dataset_type in ["mh", "paired"]:
+ # image, MH also had bigger critic
+ config.algo.critic.layer_dims = (1024, 1024)
+
+ return config
+
+
+def modify_cql_config_for_dataset(config, task_name, dataset_type, hdf5_type):
+ """
+ Modifies a CQLConfig object for training on a particular kind of dataset. This function
+ just sets algorithm hyperparameters in the algo config depending on the kind of
+ dataset.
+
+ Args:
+ config (CQLConfig instance): config to modify
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+ """
+ assert isinstance(config, CQLConfig), "must be CQLConfig"
+ assert config.algo_name == "cql", "must be CQLConfig"
+ assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
+ is_real_dataset = "real" in task_name
+ assert not is_real_dataset, "we only ran BC-RNN on real robot"
+ if not is_real_dataset:
+ assert hdf5_type != "raw", "cannot train on raw demonstrations"
+
+ with config.train.values_unlocked():
+ # CQL uses batch size 1024 (for low-dim) and 8 (for image)
+ if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ config.train.batch_size = 1024
+ else:
+ config.train.batch_size = 8
+
+ with config.algo.values_unlocked():
+ # base parameters that may get modified further
+ config.algo.optim_params.critic.learning_rate.initial = 1e-3 # learning rates
+ config.algo.optim_params.actor.learning_rate.initial = 3e-4
+ config.algo.actor.target_entropy = "default" # use automatic entropy tuning to default target value
+ config.algo.critic.deterministic_backup = True # deterministic Q-backup
+ config.algo.critic.target_q_gap = 5.0 # use Lagrange, with threshold 5.0
+ config.algo.critic.min_q_weight = 1.0
+ config.algo.target_tau = 5e-3 # tau 5e-3
+ config.algo.discount = 0.99 # discount 0.99
+ config.algo.critic.layer_dims = (300, 400) # all MLP sizes at (300, 400)
+ config.algo.actor.layer_dims = (300, 400)
+
+ if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ # update policy LR to 1e-4 for image runs
+ config.algo.optim_params.actor.learning_rate.initial = 1e-4
+
+ return config
+
+
+def modify_hbc_config_for_dataset(config, task_name, dataset_type, hdf5_type):
+ """
+ Modifies a HBCConfig object for training on a particular kind of dataset. This function
+ just sets algorithm hyperparameters in the algo config depending on the kind of
+ dataset.
+
+ Args:
+ config (HBCConfig instance): config to modify
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+ """
+ assert isinstance(config, HBCConfig), "must be HBCConfig"
+ assert config.algo_name == "hbc", "must be HBCConfig"
+ assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
+ assert hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"], "HBC only runs on low-dim"
+ is_real_dataset = "real" in task_name
+ assert not is_real_dataset, "we only ran BC-RNN on real robot"
+
+ with config.algo.values_unlocked():
+ # base parameters that may get modified further
+ config.algo.actor.optim_params.policy.learning_rate.initial = 1e-3 # learning rates
+ config.algo.planner.optim_params.goal_network.learning_rate.initial = 1e-3
+
+ config.algo.planner.vae.enabled = True # goal VAE settings
+ config.algo.planner.vae.kl_weight = 5e-4 # beta 5e-4
+ config.algo.planner.vae.latent_dim = 16 # latent dim 16
+ config.algo.planner.vae.prior.learn = True # learn GMM prior with 10 modes
+ config.algo.planner.vae.prior.is_conditioned = True
+ config.algo.planner.vae.prior.use_gmm = True
+ config.algo.planner.vae.prior.gmm_learn_weights = True
+ config.algo.planner.vae.prior.gmm_num_modes = 10
+ config.algo.planner.vae.encoder_layer_dims = (1024, 1024) # VAE network sizes
+ config.algo.planner.vae.decoder_layer_dims = (1024, 1024)
+ config.algo.planner.vae.prior_layer_dims = (1024, 1024)
+
+ config.algo.actor.rnn.hidden_dim = 400 # actor RNN dim
+ config.algo.actor.actor_layer_dims = () # no MLP layers between rnn layer and output
+
+ if dataset_type == "mg":
+ # update hyperparams for machine-generated datasets
+ config.algo.actor.rnn.hidden_dim = 100
+ config.algo.actor.actor_layer_dims = (1024, 1024)
+
+ return config
+
+
+def modify_iris_config_for_dataset(config, task_name, dataset_type, hdf5_type):
+ """
+ Modifies a IRISConfig object for training on a particular kind of dataset. This function
+ just sets algorithm hyperparameters in the algo config depending on the kind of
+ dataset.
+
+ Args:
+ config (IRISConfig instance): config to modify
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+ """
+ assert isinstance(config, IRISConfig), "must be IRISConfig"
+ assert config.algo_name == "iris", "must be IRISConfig"
+ assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
+ assert hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"], "IRIS only runs on low-dim"
+ is_real_dataset = "real" in task_name
+ assert not is_real_dataset, "we only ran BC-RNN on real robot"
+
+ with config.algo.values_unlocked():
+ # base parameters that may get modified further
+ config.algo.actor.optim_params.policy.learning_rate.initial = 1e-3 # learning rates
+ config.algo.value_planner.planner.optim_params.goal_network.learning_rate.initial = 1e-3
+ config.algo.value_planner.value.optim_params.critic.learning_rate.initial = 1e-3
+ config.algo.value_planner.value.optim_params.action_sampler.learning_rate.initial = 1e-4
+
+ config.algo.value_planner.planner.vae.enabled = True # goal VAE settings
+ config.algo.value_planner.planner.vae.kl_weight = 5e-4 # beta 5e-4
+ config.algo.value_planner.planner.vae.latent_dim = 14 # latent dim 14
+ config.algo.value_planner.planner.vae.prior.learn = True # learn GMM prior with 10 modes
+ config.algo.value_planner.planner.vae.prior.is_conditioned = True
+ config.algo.value_planner.planner.vae.prior.use_gmm = True
+ config.algo.value_planner.planner.vae.prior.gmm_learn_weights = True
+ config.algo.value_planner.planner.vae.prior.gmm_num_modes = 10
+ config.algo.value_planner.planner.vae.encoder_layer_dims = (1024, 1024) # VAE network sizes
+ config.algo.value_planner.planner.vae.decoder_layer_dims = (1024, 1024)
+ config.algo.value_planner.planner.vae.prior_layer_dims = (1024, 1024)
+
+ config.algo.value_planner.value.target_tau = 5e-4 # Value tau
+ config.algo.value_planner.value.action_sampler.vae.kl_weight = 0.5 # Value KL
+ config.algo.value_planner.value.action_sampler.vae.latent_dim = 16
+ config.algo.value_planner.value.action_sampler.actor_layer_dims = (300, 400)
+
+ config.algo.actor.rnn.hidden_dim = 400 # actor RNN dim
+ config.algo.actor.actor_layer_dims = () # no MLP layers between rnn layer and output
+
+ if dataset_type in ["mh", "paired"]:
+ # value LR 1e-4, KL weight is 0.05 for multi-human datasets
+ config.algo.value_planner.value.optim_params.critic.learning_rate.initial = 1e-4
+ config.algo.value_planner.value.action_sampler.vae.kl_weight = 0.05
+
+ if dataset_type in ["mg"]:
+ # Enable value actor and set larger target tau
+ config.algo.value_planner.value.actor.enabled = True
+ config.algo.value_planner.value.optim_params.actor.learning_rate.initial = 1e-3
+ config.algo.value_planner.value.target_tau = 5e-3
+
+ return config
+
+
+def generate_experiment_config(
+ base_exp_name,
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_name,
+ algo_config_modifier,
+ task_name,
+ dataset_type,
+ hdf5_type,
+ filter_key=None,
+ additional_name=None,
+ additional_config_modifier=None,
+):
+ """
+ Helper function to generate a config for a particular experiment.
+
+ Args:
+ base_exp_name (str): name that identifies this set of experiments
+
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_name (str): identifies the algorithm - one of ["bc", "bc_rnn", "bcq", "cql", hbc", "iris"]
+
+ algo_config_modifier (function): function to modify config to add algo hyperparameter
+ settings, given the task, dataset, and hdf5 types.
+
+ task_name (str): identify task that dataset was collected on. Only used to distinguish
+ between simulation and real-world, for an assert statement
+
+ dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).
+
+ hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image).
+
+ filter_key (str): if not None, use the provided filter key to select a subset of the
+ provided dataset
+
+ additional_name (str): if provided, will add this name to the generated experiment name, and
+ the name of the generated config json
+
+ additional_config_modifier (function): if provided, run this last function on the config
+ to make final modifications before generating the json.
+ """
+ if "real" not in task_name:
+ assert hdf5_type != "raw", "cannot train on raw demonstrations"
+
+ # decide whether to use low-dim or image training defaults
+ modifier_for_obs = modify_config_for_default_image_exp
+ if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ modifier_for_obs = modify_config_for_default_low_dim_exp
+
+ algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name
+ config = config_factory(algo_name=algo_config_name)
+ # turn into default config for observation modalities (e.g.: low-dim or rgb)
+ config = modifier_for_obs(config)
+ # add in config based on the dataset
+ config = modify_config_for_dataset(
+ config=config,
+ task_name=task_name,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ base_dataset_dir=base_dataset_dir,
+ filter_key=filter_key,
+ )
+ # add in algo hypers based on dataset
+ config = algo_config_modifier(
+ config=config,
+ task_name=task_name,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ )
+ if additional_config_modifier is not None:
+ # use additional config modifier if provided
+ config = additional_config_modifier(config)
+
+ # account for filter key in experiment naming and directory naming
+ filter_key_str = "_{}".format(filter_key) if filter_key is not None else ""
+ dataset_type_dir = "{}/{}".format(dataset_type, filter_key) if filter_key is not None else dataset_type
+
+ # account for @additional_name
+ additional_name_str = "_{}".format(additional_name) if additional_name is not None else ""
+ json_name = "{}{}".format(algo_name, additional_name_str)
+
+ # set experiment name
+ with config.experiment.values_unlocked():
+ config.experiment.name = "{}_{}_{}_{}{}_{}{}".format(base_exp_name, algo_name, task_name, dataset_type, filter_key_str, hdf5_type, additional_name_str)
+ # set output folder
+ with config.train.values_unlocked():
+ if base_output_dir is None:
+ base_output_dir = config.train.output_dir
+ config.train.output_dir = os.path.join(base_output_dir, base_exp_name, algo_name, task_name, dataset_type_dir, hdf5_type, "trained_models")
+
+ # save config to json file
+ dir_to_save = os.path.join(base_config_dir, base_exp_name, task_name, dataset_type_dir, hdf5_type)
+ os.makedirs(dir_to_save, exist_ok=True)
+ json_path = os.path.join(dir_to_save, "{}.json".format(json_name))
+ config.dump(filename=json_path)
+
+ return config, json_path
+
+
+def generate_core_configs(
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_to_config_modifier,
+):
+ """
+ Helper function to generate all configs for core set of experiments.
+
+ Args:
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs
+ to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
+ """
+ core_json_paths = Config() # use for convenient nested dict
+ for task in DATASET_REGISTRY:
+ for dataset_type in DATASET_REGISTRY[task]:
+ for hdf5_type in DATASET_REGISTRY[task][dataset_type]:
+ # if not real robot dataset, skip raw hdf5
+ is_real_dataset = ("real" in task)
+ if not is_real_dataset and hdf5_type == "raw":
+ continue
+
+ # get list of algorithms to generate configs for, for this hdf5 dataset
+ algos_to_generate = ["bc", "bc_rnn", "bcq", "cql", "hbc", "iris"]
+ if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
+ # no hbc or iris for image runs
+ algos_to_generate = algos_to_generate[:-2]
+ if is_real_dataset:
+ # we only ran BC-RNN on real robot
+ algos_to_generate = ["bc_rnn"]
+
+ for algo_name in algos_to_generate:
+
+ # generate config for this experiment
+ config, json_path = generate_experiment_config(
+ base_exp_name="core",
+ base_config_dir=base_config_dir,
+ base_dataset_dir=base_dataset_dir,
+ base_output_dir=base_output_dir,
+ algo_name=algo_name,
+ algo_config_modifier=algo_to_config_modifier[algo_name],
+ task_name=task,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ )
+
+ # save json path into dict
+ core_json_paths[task][dataset_type][hdf5_type][algo_name] = json_path
+
+ return core_json_paths
+
+
+def generate_subopt_configs(
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_to_config_modifier,
+):
+ """
+ Helper function to generate all configs for the suboptimal human subsets of the multi-human datasets.
+ Note that while the paper includes the results on the can-paired dataset along with results on these
+ datasets, the configs for runs on the can-paired dataset is in the "core" set of runs.
+
+ Args:
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs
+ to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
+ """
+ subopt_json_paths = Config() # use for convenient nested dict
+ for task in ["lift", "can", "square", "transport"]:
+ # only generate configs for multi-human data subsets
+ for dataset_type in ["mh"]:
+ # only low-dim / image
+ for hdf5_type in ["low_dim", "image"]:
+
+ # get list of algorithms to generate configs for, for this hdf5 dataset
+ algos_to_generate = ["bc", "bc_rnn", "bcq", "cql", "hbc", "iris"]
+ if hdf5_type == "image":
+ # no hbc or iris for image runs
+ algos_to_generate = algos_to_generate[:-2]
+
+ for algo_name in algos_to_generate:
+
+ for fk in ["worse", "okay", "better", "worse_okay", "worse_better", "okay_better"]:
+
+ # generate config for this experiment
+ config, json_path = generate_experiment_config(
+ base_exp_name="subopt",
+ base_config_dir=base_config_dir,
+ base_dataset_dir=base_dataset_dir,
+ base_output_dir=base_output_dir,
+ algo_name=algo_name,
+ algo_config_modifier=algo_to_config_modifier[algo_name],
+ task_name=task,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ filter_key=fk,
+ )
+
+ # save json path into dict
+ dataset_type_dir = "{}/{}".format(dataset_type, fk)
+ subopt_json_paths[task][dataset_type_dir][hdf5_type][algo_name] = json_path
+
+ return subopt_json_paths
+
+
+def generate_dataset_size_configs(
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_to_config_modifier,
+):
+ """
+ Helper function to generate all configs for the dataset size ablation experiments, where BC-RNN models
+ were trained on 20% and 50% dataset sizes.
+
+ Args:
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs
+ to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
+ """
+ size_ablation_json_paths = Config() # use for convenient nested dict
+ for task in ["lift", "can", "square", "transport"]:
+ for dataset_type in ["ph", "mh"]:
+ for hdf5_type in ["low_dim", "image"]:
+
+ # only bc-rnn
+ algo_name = "bc_rnn"
+ for fk in ["20_percent", "50_percent"]:
+
+ # generate config for this experiment
+ config, json_path = generate_experiment_config(
+ base_exp_name="dataset_size",
+ base_config_dir=base_config_dir,
+ base_dataset_dir=base_dataset_dir,
+ base_output_dir=base_output_dir,
+ algo_name=algo_name,
+ algo_config_modifier=algo_to_config_modifier[algo_name],
+ task_name=task,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ filter_key=fk,
+ )
+
+ # save json path into dict
+ dataset_type_dir = "{}/{}".format(dataset_type, fk)
+ size_ablation_json_paths[task][dataset_type_dir][hdf5_type][algo_name] = json_path
+
+ return size_ablation_json_paths
+
+
+def generate_obs_ablation_configs(
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_to_config_modifier,
+):
+ """
+ Helper function to generate all configs for the observation ablation experiments, where BC and BC-RNN models
+ were trained on different versions of low-dim and image observations.
+
+ Args:
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs
+ to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
+ """
+
+ # observation config modifiers for these experiments
+ def add_eef_vel(config):
+ with config.observation.values_unlocked():
+ old_low_dim_mods = list(config.observation.modalities.obs.low_dim)
+ old_low_dim_mods.extend(["robot0_eef_vel_lin", "robot0_eef_vel_ang", "robot0_gripper_qvel"])
+ if "robot1_eef_pos" in old_low_dim_mods:
+ old_low_dim_mods.extend(["robot1_eef_vel_lin", "robot1_eef_vel_ang", "robot1_gripper_qvel"])
+ config.observation.modalities.obs.low_dim = old_low_dim_mods
+ return config
+
+ def add_proprio(config):
+ with config.observation.values_unlocked():
+ old_low_dim_mods = list(config.observation.modalities.obs.low_dim)
+ old_low_dim_mods.extend(["robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel"])
+ if "robot1_eef_pos" in old_low_dim_mods:
+ old_low_dim_mods.extend(["robot1_joint_pos_cos", "robot1_joint_pos_sin", "robot1_joint_vel"])
+ config.observation.modalities.obs.low_dim = old_low_dim_mods
+ return config
+
+ def remove_wrist(config):
+ with config.observation.values_unlocked():
+ old_image_mods = list(config.observation.modalities.obs.rgb)
+ config.observation.modalities.obs.rgb = [m for m in old_image_mods if "eye_in_hand" not in m]
+ return config
+
+ def remove_rand(config):
+ with config.observation.values_unlocked():
+ config.observation.encoder.rgb.obs_randomizer_class = None
+ return config
+
+ obs_ablation_json_paths = Config() # use for convenient nested dict
+ for task in ["square", "transport"]:
+ for dataset_type in ["ph", "mh"]:
+ for hdf5_type in ["low_dim", "image"]:
+
+ # observation modifiers to apply
+ if hdf5_type == "low_dim":
+ obs_modifiers = [add_eef_vel, add_proprio]
+ else:
+ obs_modifiers = [add_eef_vel, add_proprio, remove_wrist, remove_rand]
+
+ # only bc and bc-rnn
+ algos_to_generate = ["bc", "bc_rnn"]
+ for algo_name in algos_to_generate:
+ for obs_modifier in obs_modifiers:
+ # generate config for this experiment
+ config, json_path = generate_experiment_config(
+ base_exp_name="obs_ablation",
+ base_config_dir=base_config_dir,
+ base_dataset_dir=base_dataset_dir,
+ base_output_dir=base_output_dir,
+ algo_name=algo_name,
+ algo_config_modifier=algo_to_config_modifier[algo_name],
+ task_name=task,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ additional_name=obs_modifier.__name__,
+ additional_config_modifier=obs_modifier,
+ )
+
+ # save json path into dict
+ algo_name_str = "{}_{}".format(algo_name, obs_modifier.__name__)
+ obs_ablation_json_paths[task][dataset_type][hdf5_type][algo_name_str] = json_path
+
+ return obs_ablation_json_paths
+
+
+def generate_hyper_ablation_configs(
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_to_config_modifier,
+):
+ """
+ Helper function to generate all configs for the hyperparameter sensitivity experiments,
+ where BC-RNN models were trained on different ablations.
+
+ Args:
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs
+ to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
+ """
+
+ # observation config modifiers for these experiments
+ def change_lr(config):
+ with config.algo.values_unlocked():
+ config.algo.optim_params.policy.learning_rate.initial = 1e-3
+ return config
+
+ def change_gmm(config):
+ with config.algo.values_unlocked():
+ config.algo.gmm.enabled = False
+ return config
+
+ def change_mlp(config):
+ with config.algo.values_unlocked():
+ config.algo.actor_layer_dims = (1024, 1024)
+ return config
+
+ def change_conv(config):
+ with config.observation.values_unlocked():
+ config.observation.encoder.rgb.core_class = 'ShallowConv'
+ config.observation.encoder.rgb.core_kwargs = Config()
+ return config
+
+ def change_rnnd_low_dim(config):
+ with config.algo.values_unlocked():
+ config.algo.rnn.hidden_dim = 100
+ return config
+
+ def change_rnnd_image(config):
+ with config.algo.values_unlocked():
+ config.algo.rnn.hidden_dim = 400
+ return config
+
+ hyper_ablation_json_paths = Config() # use for convenient nested dict
+ for task in ["square", "transport"]:
+ for dataset_type in ["ph", "mh"]:
+ for hdf5_type in ["low_dim", "image"]:
+
+ # observation modifiers to apply
+ if hdf5_type == "low_dim":
+ hyper_modifiers = [change_lr, change_gmm, change_mlp, change_rnnd_low_dim]
+ else:
+ hyper_modifiers = [change_lr, change_gmm, change_conv, change_rnnd_image]
+
+ # only bc and bc-rnn
+ algo_name = "bc_rnn"
+ for hyper_modifier in hyper_modifiers:
+ # generate config for this experiment
+ config, json_path = generate_experiment_config(
+ base_exp_name="hyper_ablation",
+ base_config_dir=base_config_dir,
+ base_dataset_dir=base_dataset_dir,
+ base_output_dir=base_output_dir,
+ algo_name=algo_name,
+ algo_config_modifier=algo_to_config_modifier[algo_name],
+ task_name=task,
+ dataset_type=dataset_type,
+ hdf5_type=hdf5_type,
+ additional_name=hyper_modifier.__name__,
+ additional_config_modifier=hyper_modifier,
+ )
+
+ # save json path into dict
+ algo_name_str = "{}_{}".format(algo_name, hyper_modifier.__name__)
+ hyper_ablation_json_paths[task][dataset_type][hdf5_type][algo_name_str] = json_path
+
+ return hyper_ablation_json_paths
+
+
+def generate_d4rl_configs(
+ base_config_dir,
+ base_dataset_dir,
+ base_output_dir,
+ algo_to_config_modifier,
+):
+ """
+ Helper function to generate all configs for reproducing BCQ, CQL, and TD3-BC runs on some D4RL
+ environments.
+
+ Args:
+ base_config_dir (str): base directory to place generated configs
+
+ base_dataset_dir (str): path to directory where datasets are on disk.
+ Directory structure is expected to be consistent with the output
+ of @make_dataset_dirs in the download_datasets.py script.
+
+ base_output_dir (str): directory to save training results to. If None, will use the directory
+ from the default algorithm configs.
+
+ algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs
+ to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
+ """
+
+ def bcq_algo_config_modifier(config):
+ with config.algo.values_unlocked():
+ # all LRs 1e-3, enable actor
+ config.algo.optim_params.critic.learning_rate.initial = 1e-3
+ config.algo.optim_params.action_sampler.learning_rate.initial = 1e-3
+ config.algo.optim_params.actor.learning_rate.initial = 1e-3
+ config.algo.actor.enabled = True
+ config.algo.action_sampler.vae.kl_weight = 0.5
+ return config
+
+ def cql_algo_config_modifier(config):
+ with config.algo.values_unlocked():
+ # taken from TD3-BC settings described in their paper
+ config.algo.optim_params.critic.learning_rate.initial = 3e-4
+ config.algo.optim_params.actor.learning_rate.initial = 3e-5
+ config.algo.actor.bc_start_steps = 40000 # pre-training steps for actor
+ config.algo.critic.target_q_gap = None # no Lagrange, and fixed weight of 10.0
+ config.algo.critic.cql_weight = 10.0
+ config.algo.critic.min_q_weight = 1.0
+ config.algo.critic.deterministic_backup = True # deterministic backup (no entropy in Q-target)
+ config.algo.actor.layer_dims = (256, 256, 256) # MLP sizes
+ config.algo.critic.layer_dims = (256, 256, 256)
+ return config
+
+ def iql_algo_config_modifier(config):
+ with config.algo.values_unlocked():
+ # taken from IQL settings described in their paper
+ config.algo.target_tau = 0.005
+ config.algo.vf_quantile = 0.7
+ config.algo.adv.beta = 3.0
+ config.algo.optim_params.critic.learning_rate.initial = 3e-4
+ config.algo.optim_params.vf.learning_rate.initial = 3e-4
+ config.algo.optim_params.actor.learning_rate.initial = 3e-4
+ config.algo.actor.layer_dims = (256, 256, 256) # MLP sizes
+ config.algo.critic.layer_dims = (256, 256, 256)
+ return config
+
+ d4rl_tasks = [
+ # "halfcheetah-random-v2",
+ # "hopper-random-v2",
+ # "walker2d-random-v2",
+ "halfcheetah-medium-v2",
+ "hopper-medium-v2",
+ "walker2d-medium-v2",
+ "halfcheetah-expert-v2",
+ "hopper-expert-v2",
+ "walker2d-expert-v2",
+ "halfcheetah-medium-expert-v2",
+ "hopper-medium-expert-v2",
+ "walker2d-medium-expert-v2",
+ # "halfcheetah-medium-replay-v2",
+ # "hopper-medium-replay-v2",
+ # "walker2d-medium-replay-v2",
+ ]
+ d4rl_json_paths = Config() # use for convenient nested dict
+ for task_name in d4rl_tasks:
+ for algo_name in ["bcq", "cql", "td3_bc", "iql"]:
+ config = config_factory(algo_name=algo_name)
+
+ # hack: copy experiment and train sections from td3-bc, since that has defaults for training with D4RL
+ if algo_name != "td3_bc":
+ ref_config = config_factory(algo_name="td3_bc")
+ with config.values_unlocked():
+ config.experiment = ref_config.experiment
+ config.train = ref_config.train
+ config.observation = ref_config.observation
+ config.train.hdf5_normalize_obs = False # only TD3-BC uses observation normalization
+
+ # modify algo section for d4rl defaults
+ if algo_name == "bcq":
+ config = bcq_algo_config_modifier(config)
+ elif algo_name == "cql":
+ config = cql_algo_config_modifier(config)
+ elif algo_name == "iql":
+ config = iql_algo_config_modifier(config)
+
+ # set experiment name
+ with config.experiment.values_unlocked():
+ config.experiment.name = "{}_{}_{}".format("d4rl", algo_name, task_name)
+ # set output folder and dataset
+ with config.train.values_unlocked():
+ if base_output_dir is None:
+ base_output_dir_for_algo = "../{}_trained_models".format(algo_name)
+ else:
+ base_output_dir_for_algo = base_output_dir
+ config.train.output_dir = os.path.join(base_output_dir_for_algo, "d4rl", algo_name, task_name, "trained_models")
+ config.train.data = os.path.join(base_dataset_dir, "d4rl", "converted",
+ "{}.hdf5".format(task_name.replace("-", "_")))
+
+ # save config to json file
+ dir_to_save = os.path.join(base_config_dir, "d4rl", task_name)
+ os.makedirs(dir_to_save, exist_ok=True)
+ json_path = os.path.join(dir_to_save, "{}.json".format(algo_name))
+ config.dump(filename=json_path)
+
+ # save json path into dict
+ d4rl_json_paths[task_name][""][""][algo_name] = json_path
+
+ return d4rl_json_paths
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Directory where generated configs will be placed
+ parser.add_argument(
+ "--config_dir",
+ type=str,
+ default=None,
+ help="Directory where generated configs will be placed. Defaults to 'paper' subfolder in exps folder of repository",
+ )
+
+ # directory where released datasets are located
+ parser.add_argument(
+ "--dataset_dir",
+ type=str,
+ default=None,
+ help="Base dataset directory for released datasets. Defaults to datasets folder in repository.",
+ )
+
+ # output directory for training runs (will be written to configs)
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default=None,
+ help="Base output directory for all training runs that will be written to generated configs.",
+ )
+
+ args = parser.parse_args()
+
+ # read args
+ generated_configs_base_dir = args.config_dir
+ if generated_configs_base_dir is None:
+ generated_configs_base_dir = os.path.join(robomimic.__path__[0], "exps/paper")
+
+ datasets_base_dir = args.dataset_dir
+ if datasets_base_dir is None:
+ datasets_base_dir = os.path.join(robomimic.__path__[0], "../datasets")
+
+ output_base_dir = args.output_dir
+
+ # algo to modifier
+ algo_to_modifier = dict(
+ bc=modify_bc_config_for_dataset,
+ bc_rnn=modify_bc_rnn_config_for_dataset,
+ bcq=modify_bcq_config_for_dataset,
+ cql=modify_cql_config_for_dataset,
+ hbc=modify_hbc_config_for_dataset,
+ iris=modify_iris_config_for_dataset,
+ )
+
+ # exp name to config generator
+ exp_name_to_generator = dict(
+ core=generate_core_configs,
+ subopt=generate_subopt_configs,
+ dataset_size=generate_dataset_size_configs,
+ obs_ablation=generate_obs_ablation_configs,
+ hyper_ablation=generate_hyper_ablation_configs,
+ d4rl=generate_d4rl_configs,
+ )
+
+ # generate configs for each experiment name
+ config_json_paths = Config() # use for convenient nested dict
+ for exp_name in exp_name_to_generator:
+ config_json_paths[exp_name] = exp_name_to_generator[exp_name](
+ base_config_dir=generated_configs_base_dir,
+ base_dataset_dir=datasets_base_dir,
+ base_output_dir=output_base_dir,
+ algo_to_config_modifier=algo_to_modifier,
+ )
+
+ # write output shell scripts
+ for exp_name in config_json_paths:
+ shell_path = os.path.join(generated_configs_base_dir, "{}.sh".format(exp_name))
+ with open(shell_path, "w") as f:
+ f.write("#!/bin/bash\n\n")
+ f.write("# " + "=" * 10 + exp_name + "=" * 10 + "\n")
+ train_script_loc = os.path.join(robomimic.__path__[0], "scripts/train.py")
+
+ for task in config_json_paths[exp_name]:
+ for dataset_type in config_json_paths[exp_name][task]:
+ for hdf5_type in config_json_paths[exp_name][task][dataset_type]:
+ f.write("\n")
+ f.write("# task: {}\n".format(task))
+ if len(dataset_type) > 0:
+ f.write("# dataset type: {}\n".format(dataset_type))
+ if len(hdf5_type) > 0:
+ f.write("# hdf5 type: {}\n".format(hdf5_type))
+ for algo_name in config_json_paths[exp_name][task][dataset_type][hdf5_type]:
+ # f.write("# {}\n".format(algo_name))
+ exp_json_path = config_json_paths[exp_name][task][dataset_type][hdf5_type][algo_name]
+ cmd = "python {} --config {}\n".format(train_script_loc, exp_json_path)
+ f.write(cmd)
+ f.write("\n")
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/get_dataset_info.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/get_dataset_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..8971bcdb725adbe9a0df0e5e0b77f88b8684d153
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/get_dataset_info.py
@@ -0,0 +1,135 @@
+"""
+Helper script to report dataset information. By default, will print trajectory length statistics,
+the maximum and minimum action element in the dataset, filter keys present, environment
+metadata, and the structure of the first demonstration. If --verbose is passed, it will
+report the exact demo keys under each filter key, and the structure of all demonstrations
+(not just the first one).
+
+Args:
+ dataset (str): path to hdf5 dataset
+
+ filter_key (str): if provided, report statistics on the subset of trajectories
+ in the file that correspond to this filter key
+
+ verbose (bool): if flag is provided, print more details, like the structure of all
+ demonstrations (not just the first one)
+
+Example usage:
+
+ # run script on example hdf5 packaged with repository
+ python get_dataset_info.py --dataset ../../tests/assets/test.hdf5
+
+ # run script only on validation data
+ python get_dataset_info.py --dataset ../../tests/assets/test.hdf5 --filter_key valid
+"""
+import h5py
+import json
+import argparse
+import numpy as np
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="path to hdf5 dataset",
+ )
+ parser.add_argument(
+ "--filter_key",
+ type=str,
+ default=None,
+ help="(optional) if provided, report statistics on the subset of trajectories \
+ in the file that correspond to this filter key",
+ )
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="verbose output",
+ )
+ args = parser.parse_args()
+
+ # extract demonstration list from file
+ filter_key = args.filter_key
+ all_filter_keys = None
+ f = h5py.File(args.dataset, "r")
+ if filter_key is not None:
+ # use the demonstrations from the filter key instead
+ print("NOTE: using filter key {}".format(filter_key))
+ demos = sorted([elem.decode("utf-8") for elem in np.array(f["mask/{}".format(filter_key)])])
+ else:
+ # use all demonstrations
+ demos = sorted(list(f["data"].keys()))
+
+ # extract filter key information
+ if "mask" in f:
+ all_filter_keys = {}
+ for fk in f["mask"]:
+ fk_demos = sorted([elem.decode("utf-8") for elem in np.array(f["mask/{}".format(fk)])])
+ all_filter_keys[fk] = fk_demos
+
+ # put demonstration list in increasing episode order
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+
+ # extract length of each trajectory in the file
+ traj_lengths = []
+ action_min = np.inf
+ action_max = -np.inf
+ for ep in demos:
+ traj_lengths.append(f["data/{}/actions".format(ep)].shape[0])
+ action_min = min(action_min, np.min(f["data/{}/actions".format(ep)][()]))
+ action_max = max(action_max, np.max(f["data/{}/actions".format(ep)][()]))
+ traj_lengths = np.array(traj_lengths)
+
+ # report statistics on the data
+ print("")
+ print("total transitions: {}".format(np.sum(traj_lengths)))
+ print("total trajectories: {}".format(traj_lengths.shape[0]))
+ print("traj length mean: {}".format(np.mean(traj_lengths)))
+ print("traj length std: {}".format(np.std(traj_lengths)))
+ print("traj length min: {}".format(np.min(traj_lengths)))
+ print("traj length max: {}".format(np.max(traj_lengths)))
+ print("action min: {}".format(action_min))
+ print("action max: {}".format(action_max))
+ print("")
+ print("==== Filter Keys ====")
+ if all_filter_keys is not None:
+ for fk in all_filter_keys:
+ print("filter key {} with {} demos".format(fk, len(all_filter_keys[fk])))
+ else:
+ print("no filter keys")
+ print("")
+ if args.verbose:
+ if all_filter_keys is not None:
+ print("==== Filter Key Contents ====")
+ for fk in all_filter_keys:
+ print("filter_key {} with {} demos: {}".format(fk, len(all_filter_keys[fk]), all_filter_keys[fk]))
+ print("")
+ env_meta = json.loads(f["data"].attrs["env_args"])
+ print("==== Env Meta ====")
+ print(json.dumps(env_meta, indent=4))
+ print("")
+
+ print("==== Dataset Structure ====")
+ for ep in demos:
+ print("episode {} with {} transitions".format(ep, f["data/{}".format(ep)].attrs["num_samples"]))
+ for k in f["data/{}".format(ep)]:
+ if k in ["obs", "next_obs"]:
+ print(" key: {}".format(k))
+ for obs_k in f["data/{}/{}".format(ep, k)]:
+ shape = f["data/{}/{}/{}".format(ep, k, obs_k)].shape
+ dtype = f["data/{}/{}/{}".format(ep, k, obs_k)].dtype
+ print(" observation key {} with shape {} and dtype {}".format(obs_k, shape, dtype))
+ elif isinstance(f["data/{}/{}".format(ep, k)], h5py.Dataset):
+ key_shape = f["data/{}/{}".format(ep, k)].shape
+ print(" key: {} with shape {}".format(k, key_shape))
+
+ if not args.verbose:
+ break
+
+ f.close()
+
+ # maybe display error message
+ print("")
+ if (action_min < -1.) or (action_max > 1.):
+ raise Exception("Dataset should have actions in [-1., 1.] but got bounds [{}, {}]".format(action_min, action_max))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/give_slack_notification.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/give_slack_notification.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b23d9a4945292eba0018b16063a3a2ae7f3123
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/give_slack_notification.py
@@ -0,0 +1,53 @@
+"""
+Script to send a slack message for notifications on completed training runs.
+Super extra, but gotta love it.
+"""
+
+import os
+import argparse
+import socket
+import ssl as ssl_lib
+import certifi
+import time
+import datetime
+
+import slack_sdk
+from slack_sdk import WebClient
+from slack_sdk.errors import SlackApiError
+
+import robomimic
+import robomimic.macros as Macros
+
+
+def give_slack_notif(msg):
+ # for some reason, we need to explicitly create an SSL context
+ ssl_context = ssl_lib.create_default_context(cafile=certifi.where())
+ client = WebClient(Macros.SLACK_TOKEN, ssl=ssl_context)
+
+ try:
+ response = client.chat_postMessage(
+ channel=Macros.SLACK_USER_ID,
+ text=msg,
+ )
+ except SlackApiError as e:
+ # You will get a SlackApiError if "ok" is False
+ assert e.response["ok"] is False
+ assert e.response["error"] # str like 'invalid_auth', 'channel_not_found'
+ print(f"Got a slack error: {e.response['error']}")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--message",
+ type=str,
+ )
+ args = parser.parse_args()
+
+ # make sure to parse \n from command line
+ message = args.message.replace("\\n", "\n")
+
+ # add some metadata and send message
+ t_now = time.time()
+ time_str = datetime.datetime.fromtimestamp(t_now).strftime('%m/%d/%Y %H:%M:%S')
+ message = "Hostname: `{}`\nProcess ID: `{}`\nTimestamp: `{}`\n```{}```".format(socket.gethostname(), os.getpid(), time_str, message)
+ give_slack_notif(message)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/hyperparam_helper.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/hyperparam_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..870c739ecbf4a751b6e62c363555d451cfa68ae2
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/hyperparam_helper.py
@@ -0,0 +1,141 @@
+"""
+A useful script for generating json files and shell scripts for conducting parameter scans.
+The script takes a path to a base json file as an argument and a shell file name.
+It generates a set of new json files in the same folder as the base json file, and
+a shell file script that contains commands to run for each experiment.
+
+Instructions:
+
+(1) Start with a base json that specifies a complete set of parameters for a single
+ run. This only needs to include parameters you want to sweep over, and parameters
+ that are different from the defaults. You can set this file path by either
+ passing it as an argument (e.g. --config /path/to/base.json) or by directly
+ setting the config file in @make_generator. The new experiment jsons will be put
+ into the same directory as the base json.
+
+(2) Decide on what json parameters you would like to sweep over, and fill those in as
+ keys in @make_generator below, taking note of the hierarchical key
+ formatting using "/" or ".". Fill in corresponding values for each - these will
+ be used in creating the experiment names, and for determining the range
+ of values to sweep. Parameters that should be sweeped together should
+ be assigned the same group number.
+
+(3) Set the output script name by either passing it as an argument (e.g. --script /path/to/script.sh)
+ or by directly setting the script file in @make_generator. The script to run all experiments
+ will be created at the specified path.
+
+Args:
+ config (str): path to a base config json file that will be modified to generate config jsons.
+ The jsons will be generated in the same folder as this file.
+
+ script (str): path to output script that contains commands to run the generated training runs
+
+Example usage:
+
+ # assumes that /tmp/gen_configs/base.json has already been created (see quickstart section of docs for an example)
+ python hyperparam_helper.py --config /tmp/gen_configs/base.json --script /tmp/gen_configs/out.sh
+"""
+import argparse
+
+import robomimic
+import robomimic.utils.hyperparam_utils as HyperparamUtils
+
+
+def make_generator(config_file, script_file):
+ """
+ Implement this function to setup your own hyperparameter scan!
+ """
+ generator = HyperparamUtils.ConfigGenerator(
+ base_config_file=config_file, script_file=script_file
+ )
+
+ # use RNN with horizon 10
+ generator.add_param(
+ key="algo.rnn.enabled",
+ name="",
+ group=0,
+ values=[True],
+ )
+ generator.add_param(
+ key="train.seq_length",
+ name="",
+ group=0,
+ values=[10],
+ )
+ generator.add_param(
+ key="algo.rnn.horizon",
+ name="",
+ group=0,
+ values=[10],
+ )
+
+ # LR - 1e-3, 1e-4
+ generator.add_param(
+ key="algo.optim_params.policy.learning_rate.initial",
+ name="plr",
+ group=1,
+ values=[1e-3, 1e-4],
+ )
+
+ # GMM y / n
+ generator.add_param(
+ key="algo.gmm.enabled",
+ name="gmm",
+ group=2,
+ values=[True, False],
+ value_names=["t", "f"],
+ )
+
+ # RNN dim 400 + MLP dims (1024, 1024) vs. RNN dim 1000 + empty MLP dims ()
+ generator.add_param(
+ key="algo.rnn.hidden_dim",
+ name="rnnd",
+ group=3,
+ values=[
+ 400,
+ 1000,
+ ],
+ )
+ generator.add_param(
+ key="algo.actor_layer_dims",
+ name="mlp",
+ group=3,
+ values=[
+ [1024, 1024],
+ [],
+ ],
+ value_names=["1024", "0"],
+ )
+
+ return generator
+
+
+def main(args):
+
+ # make config generator
+ generator = make_generator(config_file=args.config, script_file=args.script)
+
+ # generate jsons and script
+ generator.generate()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Path to base json config - will override any defaults.
+ parser.add_argument(
+ "--config",
+ type=str,
+ help="path to base config json that will be modified to generate jsons. The jsons will\
+ be generated in the same folder as this file.",
+ )
+
+ # Script name to generate - will override any defaults
+ parser.add_argument(
+ "--script",
+ type=str,
+ help="path to output script that contains commands to run the generated training runs",
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/hyperparam_helper_diffusion.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/hyperparam_helper_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..085d084be7a78e73ec149474a9ded44722880883
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/hyperparam_helper_diffusion.py
@@ -0,0 +1,297 @@
+"""
+Version of hyperparam helper to easily spin up runs with different base configs and diffusion policy.
+"""
+import os
+import shutil
+import json
+import argparse
+
+import robomimic
+import robomimic.utils.hyperparam_utils as HyperparamUtils
+
+import maglev_utils
+from maglev_utils.utils.file_utils import config_generator_to_script_lines
+
+
+# set base folder for where to copy each base config and generate new configs
+CONFIG_DIR = "/tmp/diffusion_configs"
+
+# path to base robomimic training config(s)
+BASE_CONFIGS = [
+ # "~/Desktop/mimicgen_env_data/base_train_diffusion.json",
+ # "~/Desktop/mimicgen_env_data/base_train_diffusion_image.json",
+ "~/Desktop/mimicgen_env_data/base_train_diffusion.json",
+]
+
+# output directory for this set of runs
+OUTPUT_DIR = "/tmp/diffusion_runs"
+
+
+def make_generators(base_configs):
+ """Helper function to make all generators."""
+ all_settings = [
+ # # low-dim
+ # dict(
+ # dataset_paths=[
+ # "/tmp/low_dim.hdf5",
+ # ],
+ # dataset_names=[
+ # "low_dim",
+ # ],
+ # horizon=400,
+ # ),
+ # # image
+ # dict(
+ # dataset_paths=[
+ # "/tmp/image.hdf5",
+ # ],
+ # dataset_names=[
+ # "image",
+ # ],
+ # horizon=400,
+ # ),
+ dict(
+ dataset_paths=[
+ "/ext2/rebuttal/diffusion/square_ph_abs_im.hdf5",
+ ],
+ dataset_names=[
+ "square_ph_ld",
+ ],
+ horizon=400,
+ ),
+ ]
+
+ assert len(base_configs) == len(all_settings)
+ ret = []
+ for conf, setting in zip(base_configs, all_settings):
+ ret.append(make_gen(os.path.expanduser(conf), setting))
+ return ret
+
+
+def make_gen(base_config, settings):
+ """
+ Specify training configs to generate here.
+ """
+ generator = HyperparamUtils.ConfigGenerator(
+ base_config_file=base_config,
+ script_file="", # will be overriden in next step
+ )
+
+ # add some params to sweep
+ dataset_values = [[dict(path=x)] for x in settings["dataset_paths"]]
+ generator.add_param(
+ key="train.data",
+ name="ds",
+ group=0,
+ values=dataset_values,
+ value_names=settings["dataset_names"],
+ )
+
+ # rollout settings
+ generator.add_param(
+ key="experiment.rollout.horizon",
+ name="",
+ group=1,
+ values=[settings["horizon"]],
+ )
+
+ # output path
+ generator.add_param(
+ key="train.output_dir",
+ name="",
+ group=2,
+ values=[
+ OUTPUT_DIR,
+ ],
+ )
+
+ # ensure robosuite env uses absolute pose actions
+ generator.add_param(
+ key="experiment.env_meta_update_dict",
+ name="",
+ group=-1,
+ values=[
+ {"env_kwargs": {"controller_configs": {"control_delta": False}}}
+ ],
+ )
+
+ # default action spec for diffusion policy
+ generator.add_param(
+ key="train.action_keys",
+ name="",
+ group=-1,
+ values=[
+ [
+ "action_dict/abs_pos",
+ "action_dict/abs_rot_6d",
+ "action_dict/gripper",
+ # "actions",
+ ],
+ ],
+ )
+ generator.add_param(
+ key="train.action_config",
+ name="",
+ group=-1,
+ values=[
+ {
+ "actions":{
+ "normalization": None,
+ },
+ "action_dict/abs_pos": {
+ "normalization": "min_max"
+ },
+ "action_dict/abs_rot_axis_angle": {
+ "normalization": "min_max",
+ "format": "rot_axis_angle"
+ },
+ "action_dict/abs_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/rel_pos": {
+ "normalization": None,
+ },
+ "action_dict/rel_rot_axis_angle": {
+ "normalization": None,
+ "format": "rot_axis_angle"
+ },
+ "action_dict/rel_rot_6d": {
+ "normalization": None,
+ "format": "rot_6d"
+ },
+ "action_dict/gripper": {
+ "normalization": None,
+ }
+ }
+ ],
+ )
+
+ # num data workers 4 by default (for both low-dim and image) and cache mode "low_dim"
+ generator.add_param(
+ key="train.num_data_workers",
+ name="",
+ group=-1,
+ values=[4],
+ )
+ generator.add_param(
+ key="train.hdf5_cache_mode",
+ name="",
+ group=-1,
+ values=["low_dim"],
+ )
+
+ # num epochs 1000 for both low-dim and image
+ generator.add_param(
+ key="train.num_epochs",
+ name="",
+ group=-1,
+ values=[1000],
+ )
+
+ # set low-rate of eval - every 100 epochs
+ generator.add_param(
+ key="experiment.save.every_n_epochs",
+ name="",
+ group=-1,
+ values=[100],
+ )
+ generator.add_param(
+ key="experiment.rollout.rate",
+ name="",
+ group=-1,
+ values=[100],
+ )
+
+ # set noise scheduler
+ use_ddim = True
+ inf_steps = [(100, 10), (50, 5)]
+ # use_ddim = False
+ # inf_steps = []
+
+ generator.add_param(
+ key="algo.ddim.enabled",
+ name="ddim" if use_ddim else "",
+ group=1001,
+ values=[
+ use_ddim,
+ ],
+ value_names=[
+ "t" if use_ddim else "f",
+ ],
+ )
+ generator.add_param(
+ key="algo.ddpm.enabled",
+ name="ddpm" if not use_ddim else "",
+ group=1001,
+ values=[
+ (not use_ddim),
+ ],
+ value_names=[
+ "f" if not use_ddim else "t",
+ ],
+ )
+
+ if len(inf_steps) > 0:
+ train_inf_steps = [x[0] for x in inf_steps]
+ eval_inf_steps = [x[1] for x in inf_steps]
+ # set inf steps
+ generator.add_param(
+ key="algo.ddim.num_train_timesteps" if use_ddim else "algo.ddpm.num_train_timesteps",
+ name="train",
+ group=1002,
+ values=train_inf_steps,
+ )
+ generator.add_param(
+ key="algo.ddim.num_inference_timesteps" if use_ddim else "algo.ddpm.num_inference_timesteps",
+ name="eval",
+ group=1002,
+ values=eval_inf_steps,
+ )
+
+ # # seed
+ # generator.add_param(
+ # key="train.seed",
+ # name="seed",
+ # group=100000,
+ # values=[101, 102, 103],
+ # )
+
+ return generator
+
+
+def main(args):
+
+ # make config generators
+ generators = make_generators(base_configs=BASE_CONFIGS)
+
+ if args.config_dir is None:
+ args.config_dir = CONFIG_DIR
+
+ if os.path.exists(args.config_dir):
+ ans = input("Non-empty dir at {} will be removed.\nContinue (y / n)? \n".format(args.config_dir))
+ if ans != "y":
+ exit()
+ shutil.rmtree(args.config_dir)
+
+ all_json_files, run_lines = config_generator_to_script_lines(generators, config_dir=args.config_dir)
+
+ print("configs")
+ print(json.dumps(all_json_files, indent=4))
+ print("runs")
+ print(json.dumps(run_lines, indent=4))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Path to base json config - will override any defaults.
+ parser.add_argument(
+ "--config_dir",
+ type=str,
+ help="path to base config json that will be modified to generate jsons. The jsons will\
+ be generated in the same folder as this file.",
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/playback_dataset.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/playback_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3251ba6571e2c5df2a476655ea2c84fd57c7fe5e
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/playback_dataset.py
@@ -0,0 +1,514 @@
+"""
+A script to visualize dataset trajectories by loading the simulation states
+one by one or loading the first state and playing actions back open-loop.
+The script can generate videos as well, by rendering simulation frames
+during playback. The videos can also be generated using the image observations
+in the dataset (this is useful for real-robot datasets) by using the
+--use-obs argument.
+
+Args:
+ dataset (str): path to hdf5 dataset
+
+ filter_key (str): if provided, use the subset of trajectories
+ in the file that correspond to this filter key
+
+ n (int): if provided, stop after n trajectories are processed
+
+ use-obs (bool): if flag is provided, visualize trajectories with dataset
+ image observations instead of simulator
+
+ use-actions (bool): if flag is provided, use open-loop action playback
+ instead of loading sim states
+
+ render (bool): if flag is provided, use on-screen rendering during playback
+
+ video_path (str): if provided, render trajectories to this video file path
+
+ video_skip (int): render frames to a video every @video_skip steps
+
+ render_image_names (str or [str]): camera name(s) / image observation(s) to
+ use for rendering on-screen or to video
+
+ first (bool): if flag is provided, use first frame of each episode for playback
+ instead of the entire episode. Useful for visualizing task initializations.
+
+Example usage below:
+
+ # force simulation states one by one, and render agentview and wrist view cameras to video
+ python playback_dataset.py --dataset /path/to/dataset.hdf5 \
+ --render_image_names agentview robot0_eye_in_hand \
+ --video_path /tmp/playback_dataset.mp4
+
+ # playback the actions in the dataset, and render agentview camera during playback to video
+ python playback_dataset.py --dataset /path/to/dataset.hdf5 \
+ --use-actions --render_image_names agentview \
+ --video_path /tmp/playback_dataset_with_actions.mp4
+
+ # use the observations stored in the dataset to render videos of the dataset trajectories
+ python playback_dataset.py --dataset /path/to/dataset.hdf5 \
+ --use-obs --render_image_names agentview_image \
+ --video_path /tmp/obs_trajectory.mp4
+
+ # visualize depth observations along with image observations
+ python playback_dataset.py --dataset /path/to/dataset.hdf5 \
+ --use-obs --render_image_names agentview_image \
+ --render_depth_names agentview_depth \
+ --video_path /tmp/obs_trajectory.mp4
+
+ # visualize initial states in the demonstration data
+ python playback_dataset.py --dataset /path/to/dataset.hdf5 \
+ --first --render_image_names agentview \
+ --video_path /tmp/dataset_task_inits.mp4
+"""
+
+import os
+import json
+import h5py
+import argparse
+import imageio
+import matplotlib.pyplot as plt
+import matplotlib.cm as cm
+import numpy as np
+
+import robomimic
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.file_utils as FileUtils
+from robomimic.utils.vis_utils import depth_to_rgb
+from robomimic.envs.env_base import EnvBase, EnvType
+
+try:
+ import mimicgen
+except ImportError:
+ print("WARNING: could not import mimicgen envs")
+
+
+# Define default cameras to use for each env type
+DEFAULT_CAMERAS = {
+ EnvType.ROBOSUITE_TYPE: ["agentview"],
+ EnvType.IG_MOMART_TYPE: ["rgb"],
+ EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
+ EnvType.REAL_TYPE: ["front_image"],
+ EnvType.GPRS_REAL_TYPE: ["front_image"],
+}
+
+
+def add_red_border(frame):
+ """Add a red border to image frame."""
+ border_size = int(0.05 * min(frame.shape[0], frame.shape[1])) # 5% of image
+ frame[:border_size, :, :] = [255., 0., 0.]
+ frame[-border_size:, :, :] = [255., 0., 0.]
+ frame[:, :border_size, :] = [255., 0., 0.]
+ frame[:, -border_size:, :] = [255., 0., 0.]
+ return frame
+
+
+def depth_to_rgb(depth_map, depth_min=None, depth_max=None):
+ """
+ Convert depth map to rgb array by computing normalized depth values in [0, 1].
+ """
+ # normalize depth map into [0, 1]
+ if depth_min is None:
+ depth_min = depth_map.min()
+ if depth_max is None:
+ depth_max = depth_map.max()
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+ # depth_map = np.clip(depth_map / 3., 0., 1.)
+ if len(depth_map.shape) == 3:
+ assert depth_map.shape[-1] == 1
+ depth_map = depth_map[..., 0]
+ assert len(depth_map.shape) == 2 # [H, W]
+ return (255. * cm.hot(depth_map, 3)).astype(np.uint8)[..., :3]
+
+
+def playback_trajectory_with_env(
+ env,
+ initial_state,
+ states,
+ actions=None,
+ render=False,
+ video_writer=None,
+ video_skip=5,
+ camera_names=None,
+ first=False,
+ interventions=None,
+ real=False,
+):
+ """
+ Helper function to playback a single trajectory using the simulator environment.
+ If @actions are not None, it will play them open-loop after loading the initial state.
+ Otherwise, @states are loaded one by one.
+
+ Args:
+ env (instance of EnvBase): environment
+ initial_state (dict): initial simulation state to load
+ states (list of dict or np.array): array of simulation states to load
+ actions (np.array): if provided, play actions back open-loop instead of using @states
+ render (bool): if True, render on-screen
+ video_writer (imageio writer): video writer
+ video_skip (int): determines rate at which environment frames are written to video
+ camera_names (list): determines which camera(s) are used for rendering. Pass more than
+ one to output a video with multiple camera views concatenated horizontally.
+ first (bool): if True, only use the first frame of each episode.
+ real (bool): if True, playback is happening on real robot
+ """
+ assert isinstance(env, EnvBase)
+
+ write_video = (video_writer is not None)
+ video_count = 0
+ assert not (render and write_video)
+
+ # load the initial state
+ env.reset()
+ if real:
+ assert actions is not None, "must supply actions for real robot playback"
+ traj_len = actions.shape[0]
+ input("ready for next episode? hit enter to continue")
+ else:
+ env.reset_to(initial_state)
+ traj_len = len(states)
+
+ action_playback = (actions is not None)
+ if action_playback:
+ assert len(states) == actions.shape[0]
+
+ for i in range(traj_len):
+ if action_playback:
+ env.step(actions[i])
+ if (i < traj_len - 1) and not real:
+ # check whether the actions deterministically lead to the same recorded states
+ state_playback = env.get_state()["states"]
+ if isinstance(state_playback, dict):
+ # state is dict, so assert equality for all keys
+ for k in state_playback:
+ if not np.all(np.equal(states[i + 1][k], state_playback[k])):
+ err = np.linalg.norm(states[i + 1][k] - state_playback[k])
+ print("warning: playback diverged by {} at step {} state key {}".format(err, i, k))
+ else:
+ if not np.all(np.equal(states[i + 1], state_playback)):
+ err = np.linalg.norm(states[i + 1] - state_playback)
+ print("warning: playback diverged by {} at step {}".format(err, i))
+
+ else:
+ env.reset_to({"states" : states[i]})
+
+ # on-screen render
+ if render:
+ env.render(mode="human", camera_name=camera_names[0])
+
+ # video render
+ if write_video:
+ if video_count % video_skip == 0:
+ video_img = []
+ for cam_name in camera_names:
+ frame = env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name)
+ if (interventions is not None) and interventions[i]:
+ # add red border to frame
+ frame = add_red_border(frame=frame)
+ video_img.append(frame)
+ video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
+ video_writer.append_data(video_img)
+ video_count += 1
+
+ if first:
+ break
+
+
+def playback_trajectory_with_obs(
+ traj_grp,
+ video_writer,
+ video_skip=5,
+ image_names=None,
+ depth_names=None,
+ first=False,
+ intervention=False,
+):
+ """
+ This function reads all "rgb" (and possibly "depth") observations in the dataset trajectory and
+ writes them into a video.
+
+ Args:
+ traj_grp (hdf5 file group): hdf5 group which corresponds to the dataset trajectory to playback
+ video_writer (imageio writer): video writer
+ video_skip (int): determines rate at which environment frames are written to video
+ image_names (list): determines which image observations are used for rendering. Pass more than
+ one to output a video with multiple image observations concatenated horizontally.
+ depth_names (list): determines which depth observations are used for rendering (if any).
+ first (bool): if True, only use the first frame of each episode.
+ intervention (bool): if True, denote intervention timesteps with a red border
+ """
+ assert image_names is not None, "error: must specify at least one image observation to use in @image_names"
+ video_count = 0
+
+ if depth_names is not None:
+ # compute min and max depth value across trajectory for normalization
+ depth_min = { k : traj_grp["obs/{}".format(k)][:].min() for k in depth_names }
+ depth_max = { k : traj_grp["obs/{}".format(k)][:].max() for k in depth_names }
+
+ traj_len = traj_grp["actions"].shape[0]
+ frame_inds = range(traj_len)
+ if first:
+ video_skip = 1 # keep all frames
+ if intervention:
+ # find where interventions begin (0 to 1 edge) and get frames right before them
+ if len(traj_grp["interventions"].shape) == 2:
+ all_interventions = traj_grp["interventions"][:, 0].astype(int)
+ else:
+ all_interventions = traj_grp["interventions"][:].astype(int)
+ frame_inds = list(np.nonzero((all_interventions[1:] - all_interventions[:-1]) > 0)[0])
+ else:
+ frame_inds = range(1)
+
+ if depth_names is not None:
+ # compute min and max depth value across trajectory for normalization
+ depth_min = { k : traj_grp["obs/{}".format(k)][:].min() for k in depth_names }
+ depth_max = { k : traj_grp["obs/{}".format(k)][:].max() for k in depth_names }
+
+ for i in frame_inds:
+ if video_count % video_skip == 0:
+ # concatenate image obs together
+ im = [traj_grp["obs/{}".format(k)][i] for k in image_names]
+ depth = [depth_to_rgb(traj_grp["obs/{}".format(k)][i], depth_min=depth_min[k], depth_max=depth_max[k]) for k in depth_names] if depth_names is not None else []
+ frame = np.concatenate(im + depth, axis=1)
+ video_writer.append_data(frame)
+ video_count += 1
+
+
+def playback_dataset(args, env=None):
+ # some arg checking
+ write_video = (args.video_path is not None)
+ assert not (args.render and write_video) # either on-screen or video but not both
+ if args.absolute:
+ assert args.use_actions
+
+ # Auto-fill camera rendering info if not specified
+ if args.render_image_names is None:
+ # We fill in the automatic values
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
+ env_type = EnvUtils.get_env_type(env_meta=env_meta)
+ args.render_image_names = DEFAULT_CAMERAS[env_type]
+
+ if args.render:
+ # on-screen rendering can only support one camera
+ assert len(args.render_image_names) == 1
+
+ if args.use_obs:
+ assert write_video, "playback with observations can only write to video"
+ assert not args.use_actions, "playback with observations is offline and does not support action playback"
+
+ if args.render_depth_names is not None:
+ assert args.use_obs, "depth observations can only be visualized from observations currently"
+
+ # create environment only if not playing back with observations
+ if not args.use_obs:
+ # need to make sure ObsUtils knows which observations are images, but it doesn't matter
+ # for playback since observations are unused. Pass a dummy spec here.
+ dummy_spec = dict(
+ obs=dict(
+ low_dim=["robot0_eef_pos"],
+ rgb=[],
+ ),
+ )
+
+ # some operations for playback are env-type-specific
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
+ is_robosuite_env = EnvUtils.is_robosuite_env(env_meta)
+ is_real_robot = EnvUtils.is_real_robot_env(env_meta) or EnvUtils.is_real_robot_gprs_env(env_meta)
+
+ if args.absolute:
+ # modify env-meta to tell the environment to expect absolute actions
+ assert is_robosuite_env or is_real_robot, "only these support absolute actions for now"
+ if is_robosuite_env:
+ env_meta["env_kwargs"]["controller_configs"]["control_delta"] = False
+ else:
+ env_meta["env_kwargs"]["absolute_actions"] = True
+
+ if env is None:
+ if is_real_robot:
+ # TODO: update hardcoded keys on real robot
+ dummy_spec["obs"]["rgb"] = ["front_image", "wrist_image", "side_image"]
+ dummy_spec["obs"]["depth"] = ["front_image_depth", "wrist_image_depth", "side_image_depth"]
+ ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
+ env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=args.render, render_offscreen=write_video)
+
+ f = h5py.File(args.dataset, "r")
+
+ # list of all demonstration episodes (sorted in increasing number order)
+ if args.filter_key is not None:
+ print("using filter key: {}".format(args.filter_key))
+ demos = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(args.filter_key)])]
+ else:
+ demos = list(f["data"].keys())
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+
+ # maybe reduce the number of demonstrations to playback
+ if args.n is not None:
+ demos = demos[:args.n]
+
+ # maybe dump video
+ video_writer = None
+ if write_video:
+ fps = 5 if args.first else 20
+ video_writer = imageio.get_writer(args.video_path, fps=fps)
+
+ for ind in range(len(demos)):
+ ep = demos[ind]
+ print("Playing back episode: {}".format(ep))
+
+ if args.use_obs:
+ playback_trajectory_with_obs(
+ traj_grp=f["data/{}".format(ep)],
+ video_writer=video_writer,
+ video_skip=args.video_skip,
+ image_names=args.render_image_names,
+ depth_names=args.render_depth_names,
+ first=args.first,
+ intervention=args.intervention,
+ )
+ continue
+
+ # prepare states to reload from
+ if not is_real_robot:
+ states = f["data/{}/states".format(ep)][()]
+ initial_state = dict(states=states[0])
+ if is_robosuite_env:
+ initial_state["model"] = f["data/{}".format(ep)].attrs["model_file"]
+
+ # supply actions if using open-loop action playback
+ actions = None
+ if args.use_actions:
+ if args.absolute:
+ actions = f["data/{}/actions_abs".format(ep)][()]
+ else:
+ actions = f["data/{}/actions".format(ep)][()]
+
+ if is_real_robot:
+ assert actions is not None
+ states = np.zeros(actions.shape[0])
+ initial_state = dict(states=states[0])
+
+ # supply interventions if we need them for visualization
+ interventions = None
+ if args.intervention:
+ interventions = f["data/{}/interventions".format(ep)][()]
+
+ playback_trajectory_with_env(
+ env=env,
+ initial_state=initial_state,
+ states=states, actions=actions,
+ render=args.render,
+ video_writer=video_writer,
+ video_skip=args.video_skip,
+ camera_names=args.render_image_names,
+ first=args.first,
+ interventions=interventions,
+ real=is_real_robot,
+ )
+
+ f.close()
+ if write_video:
+ video_writer.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="path to hdf5 dataset",
+ )
+ parser.add_argument(
+ "--filter_key",
+ type=str,
+ default=None,
+ help="(optional) filter key, to select a subset of trajectories in the file",
+ )
+
+ # number of trajectories to playback. If omitted, playback all of them.
+ parser.add_argument(
+ "--n",
+ type=int,
+ default=None,
+ help="(optional) stop after n trajectories are played",
+ )
+
+ # Use image observations instead of doing playback using the simulator env.
+ parser.add_argument(
+ "--use-obs",
+ action='store_true',
+ help="visualize trajectories with dataset image observations instead of simulator",
+ )
+
+ # Playback stored dataset actions open-loop instead of loading from simulation states.
+ parser.add_argument(
+ "--use-actions",
+ action='store_true',
+ help="use open-loop action playback instead of loading sim states",
+ )
+
+ # TODO: clean up this arg
+ parser.add_argument(
+ "--absolute",
+ action='store_true',
+ help="use absolute actions for open-loop action playback",
+ )
+
+ # Whether to render playback to screen
+ parser.add_argument(
+ "--render",
+ action='store_true',
+ help="on-screen rendering",
+ )
+
+ # Dump a video of the dataset playback to the specified path
+ parser.add_argument(
+ "--video_path",
+ type=str,
+ default=None,
+ help="(optional) render trajectories to this video file path",
+ )
+
+ # How often to write video frames during the playback
+ parser.add_argument(
+ "--video_skip",
+ type=int,
+ default=5,
+ help="render frames to video every n steps",
+ )
+
+ # camera names to render, or image observations to use for writing to video
+ parser.add_argument(
+ "--render_image_names",
+ type=str,
+ nargs='+',
+ default=None,
+ help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. Default is"
+ "None, which corresponds to a predefined camera for each env type",
+ )
+
+ # depth observations to use for writing to video
+ parser.add_argument(
+ "--render_depth_names",
+ type=str,
+ nargs='+',
+ default=None,
+ help="(optional) depth observation(s) to use for rendering to video"
+ )
+
+ # Only use the first frame of each episode
+ parser.add_argument(
+ "--first",
+ action='store_true',
+ help="use first frame of each episode",
+ )
+
+ # Denote intervention timesteps with a red border in the frame
+ parser.add_argument(
+ "--intervention",
+ action='store_true',
+ help="denote intervention timesteps with a red border in the frame",
+ )
+
+ args = parser.parse_args()
+ playback_dataset(args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/postprocess_dataset_intervention_segments.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/postprocess_dataset_intervention_segments.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e9fec99aba334a9ba8298c3e73b6b391e29e921
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/postprocess_dataset_intervention_segments.py
@@ -0,0 +1,220 @@
+"""
+Script to postprocess a dataset by splitting each trajectory up into new trajectories
+that only consists of continuous intervention segments.
+"""
+import os
+import json
+import h5py
+import argparse
+import numpy as np
+
+import robomimic.utils.file_utils as FileUtils
+
+
+def write_intervention_segments_as_trajectories(
+ src_ep_grp,
+ dst_grp,
+ start_ep_write_ind,
+ same=False,
+):
+ """
+ Helper function to extract intervention segments from a source demonstration and use their indices to
+ write the corresponding subset of each trajectory to a new trajectory.
+
+ Returns:
+ end_ep_write_ind (int): updated episode index after writing trajectories to destination file
+ num_traj (int): number of trajectories written to destination file
+ total_samples (int): total number of samples written to destination file
+ same (bool): if True, write all intevrention segments to the same trajectory
+ """
+
+ # get segments
+ interventions = src_ep_grp["interventions"][()].reshape(-1).astype(int)
+ segments = FileUtils.get_intervention_segments(interventions)
+
+ ep_write_ind = start_ep_write_ind
+ total_samples = 0
+ num_traj = len(segments)
+ keys_to_try_and_copy = ["states", "obs", "next_obs", "rewards", "dones", "actions_abs", "datagen_info"]
+
+ if same:
+ # concatenate information across intervention segments and write to single episode
+ num_traj = 1
+ dst_grp_name = "demo_{}".format(ep_write_ind)
+ dst_ep_grp = dst_grp.create_group(dst_grp_name)
+ for k in keys_to_try_and_copy:
+ should_compress = (k in ["obs", "next_obs"])
+ if k in src_ep_grp:
+ if isinstance(src_ep_grp[k], h5py.Group):
+ for k2 in src_ep_grp[k]:
+ assert isinstance(src_ep_grp[k][k2], h5py.Dataset)
+ data = np.concatenate(
+ [src_ep_grp[k][k2][seg_start_ind : seg_end_ind] for seg_start_ind, seg_end_ind in segments],
+ axis=0,
+ )
+ if should_compress:
+ dst_ep_grp.create_dataset("{}/{}".format(k, k2), data=data, compression="gzip")
+ else:
+ dst_ep_grp.create_dataset("{}/{}".format(k, k2), data=data)
+ else:
+ assert isinstance(src_ep_grp[k], h5py.Dataset)
+ data = np.concatenate(
+ [src_ep_grp[k][seg_start_ind : seg_end_ind] for seg_start_ind, seg_end_ind in segments],
+ axis=0,
+ )
+ if should_compress:
+ dst_ep_grp.create_dataset("{}".format(k), data=data, compression="gzip")
+ else:
+ dst_ep_grp.create_dataset("{}".format(k), data=data)
+
+ # manually copy actions since they might need truncation
+ actions = np.concatenate([src_ep_grp["actions"][seg_start_ind : seg_end_ind] for seg_start_ind, seg_end_ind in segments], axis=0)
+ if actions.shape[-1] != 7:
+ actions = actions[..., :7]
+ dst_ep_grp.create_dataset("actions", data=actions)
+
+ # mimicgen metadata
+ if "src_demo_inds" in src_ep_grp:
+ dst_ep_grp.create_dataset("src_demo_inds", data=np.array(src_ep_grp["src_demo_inds"][:]))
+ if "src_demo_labels" in src_ep_grp:
+ dst_ep_grp.create_dataset("src_demo_labels", data=np.array(src_ep_grp["src_demo_labels"][:]))
+
+ # copy attributes too
+ for k in src_ep_grp.attrs:
+ dst_ep_grp.attrs[k] = src_ep_grp.attrs[k]
+ dst_ep_grp.attrs["num_samples"] = np.sum([(seg_end_ind - seg_start_ind) for seg_start_ind, seg_end_ind in segments])
+
+ # update variables for next iter
+ ep_write_ind += 1
+ total_samples += dst_ep_grp.attrs["num_samples"]
+ print(" wrote trajectory to dst grp {} with num samples {}".format(dst_grp_name, dst_ep_grp.attrs["num_samples"]))
+ else:
+ # write each segment to new episode
+ for seg_start_ind, seg_end_ind in segments:
+ dst_grp_name = "demo_{}".format(ep_write_ind)
+ dst_ep_grp = dst_grp.create_group(dst_grp_name)
+
+ # copy over subsequence from source trajectory into destination trajectory
+ for k in keys_to_try_and_copy:
+ should_compress = (k in ["obs", "next_obs"])
+ if k in src_ep_grp:
+ if isinstance(src_ep_grp[k], h5py.Group):
+ for k2 in src_ep_grp[k]:
+ assert isinstance(src_ep_grp[k][k2], h5py.Dataset)
+ if should_compress:
+ dst_ep_grp.create_dataset("{}/{}".format(k, k2), data=np.array(src_ep_grp[k][k2][seg_start_ind : seg_end_ind]), compression="gzip")
+ else:
+ dst_ep_grp.create_dataset("{}/{}".format(k, k2), data=np.array(src_ep_grp[k][k2][seg_start_ind : seg_end_ind]))
+ else:
+ assert isinstance(src_ep_grp[k], h5py.Dataset)
+ if should_compress:
+ dst_ep_grp.create_dataset("{}".format(k), data=np.array(src_ep_grp[k][seg_start_ind : seg_end_ind]), compression="gzip")
+ else:
+ dst_ep_grp.create_dataset("{}".format(k), data=np.array(src_ep_grp[k][seg_start_ind : seg_end_ind]))
+
+ # manually copy actions since they might need truncation
+ actions = np.array(src_ep_grp["actions"][seg_start_ind : seg_end_ind])
+ if actions.shape[-1] != 7:
+ actions = actions[..., :7]
+ dst_ep_grp.create_dataset("actions", data=actions)
+
+ # mimicgen metadata
+ if "src_demo_inds" in src_ep_grp:
+ dst_ep_grp.create_dataset("src_demo_inds", data=np.array(src_ep_grp["src_demo_inds"][:]))
+ if "src_demo_labels" in src_ep_grp:
+ dst_ep_grp.create_dataset("src_demo_labels", data=np.array(src_ep_grp["src_demo_labels"][:]))
+
+ # copy attributes too
+ for k in src_ep_grp.attrs:
+ dst_ep_grp.attrs[k] = src_ep_grp.attrs[k]
+ dst_ep_grp.attrs["num_samples"] = (seg_end_ind - seg_start_ind)
+
+ # update variables for next iter
+ ep_write_ind += 1
+ total_samples += dst_ep_grp.attrs["num_samples"]
+ print(" wrote trajectory to dst grp {} with num samples {}".format(dst_grp_name, dst_ep_grp.attrs["num_samples"]))
+
+ return ep_write_ind, num_traj, total_samples
+
+
+def postprocess_dataset_intervention_segments(args):
+ # list of all demonstration episodes (sorted in increasing number order)
+ f = h5py.File(args.dataset, "r")
+ demos = list(f["data"].keys())
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+
+ # maybe reduce the number of demonstrations to playback
+ if args.n is not None:
+ demos = demos[:args.n]
+
+ # output file in same directory as input file
+ output_path = os.path.join(os.path.dirname(args.dataset), args.output_name)
+ f_out = h5py.File(output_path, "w")
+ data_grp = f_out.create_group("data")
+ print("\ninput file: {}".format(args.dataset))
+ print("output file: {}\n".format(output_path))
+
+ ep_write_ind = 0
+ num_traj = 0
+ total_samples = 0
+ for ind in range(len(demos)):
+ ep = demos[ind]
+ src_ep_grp = f["data/{}".format(ep)]
+ print("src grp: {} with {} samples".format(ep, src_ep_grp.attrs["num_samples"]))
+ ep_write_ind, ep_num_traj, ep_total_samples = write_intervention_segments_as_trajectories(
+ src_ep_grp=src_ep_grp,
+ dst_grp=data_grp,
+ start_ep_write_ind=ep_write_ind,
+ same=args.same,
+ )
+ num_traj += ep_num_traj
+ total_samples += ep_total_samples
+
+ # TODO: update filter keys based on which source demos created which target demos
+ # if "mask" in f:
+ # f.copy("mask", f_out)
+
+ # global metadata
+ data_grp.attrs["total"] = total_samples
+ data_grp.attrs["env_args"] = f["data"].attrs["env_args"] # environment info
+ print("\nWrote {} trajectories from src with {} trajectories to {}".format(num_traj, len(demos), output_path))
+
+ f.close()
+ f_out.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ required=True,
+ help="path to input hdf5 dataset",
+ )
+ # name of hdf5 to write - it will be in the same directory as @dataset
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ required=True,
+ help="name of output hdf5 dataset",
+ )
+
+ # specify number of demos to process - useful for debugging conversion with a handful
+ # of trajectories
+ parser.add_argument(
+ "--n",
+ type=int,
+ default=None,
+ help="(optional) stop after n trajectories are processed",
+ )
+
+ # write all intervention segments to the same demo key (so they will be treated as a contiguous trajectory in time)
+ parser.add_argument(
+ "--same",
+ action='store_true',
+ help="write all intervention segments to the same demo key (so they will be treated as a contiguous trajectory in time",
+ )
+
+ args = parser.parse_args()
+ postprocess_dataset_intervention_segments(args)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/remove_idle_segments.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/remove_idle_segments.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d639363b384a69c87ab93d9beb80e916cd6f2f
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/remove_idle_segments.py
@@ -0,0 +1,199 @@
+"""
+Script to remove idle segments from a real robot hdf5.
+"""
+import os
+import h5py
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+import robomimic.utils.file_utils as FileUtils
+from robomimic.scripts.postprocess_dataset_intervention_segments import postprocess_dataset_intervention_segments
+
+
+def get_idle_segments_in_trajectory(
+ ep_grp,
+ obs_pos_key="ee_pose",
+ min_segment_length=1,
+ threshold=1e-4,
+ verbose=False,
+):
+ """
+ Returns a mask that corresponds to idle segments in the trajectory.
+
+ Args:
+ ep_grp (h5py.Group): hdf5 group that corresponds to a demo key (such as "demo_0")
+ obs_pos_key (str): key for eef pos observations
+ min_segment_length (int): minimum length of idle segment
+ threshold (float): threshold for delta eef pos differences - everything below this threshold
+ value is considered idle
+ verbose (bool): if True, print some helpful info
+
+ Returns:
+ idle_segment_mask (np.array): array with value of 1 during an idle segment
+ """
+ if verbose:
+ print(ep_grp)
+ eef_pos = ep_grp["obs/{}".format(obs_pos_key)][:, :3]
+ delta_eef_pos_norms = np.linalg.norm(np.diff(eef_pos, axis=0), axis=1)
+
+ # note: pad with 0 at start to make sure indices correspond to indices in @eef_pos (otherwise we're off by one due to the difference calculation)
+ idle_segment_mask = np.array([0] + (delta_eef_pos_norms < threshold).astype(int).tolist())
+ idle_segments = FileUtils.get_intervention_segments(idle_segment_mask)
+
+ # filter out segments that are too short
+ ret_mask = np.zeros(eef_pos.shape[0]).astype(int)
+ for seg in idle_segments:
+ if seg[1] - seg[0] >= min_segment_length:
+ ret_mask[seg[0] : seg[1]] = 1
+
+ if verbose:
+ print("segment: {}".format(seg))
+ # print norms N timesteps before and after window to get a sense of nearby values
+ prev_norms = delta_eef_pos_norms[max(seg[0] - 6, 0) : seg[0] - 1]
+ print("prev_norms")
+ print(prev_norms)
+ post_norms = delta_eef_pos_norms[seg[1] - 1 : min(seg[1] + 4, eef_pos.shape[0] - 1)]
+ print("post_norms")
+ print(post_norms)
+
+ return ret_mask
+
+
+def write_non_idle_segments_as_interventions(hdf5_path, n=None, min_segment_length=1, threshold=1e-4):
+ """
+ Modifies the hdf5 in-place by splitting each trajectory into idle and non-idle segments, and
+ writing the result as an "interventions" key in each trajectory, where the interventions correspond
+ to non-idle segments.
+ """
+
+ # get demo keys
+ f = h5py.File(args.dataset, "a")
+ demos = list(f["data"].keys())
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+ if args.n is not None:
+ demos = demos[:args.n]
+
+ # for each demo key, get idle segment, and write to interventions
+ for demo_key in tqdm(demos):
+ ep_grp = f["data/{}".format(demo_key)]
+ idle_seg_mask = get_idle_segments_in_trajectory(
+ ep_grp=ep_grp,
+ obs_pos_key="ee_pose",
+ min_segment_length=min_segment_length,
+ threshold=threshold,
+ )
+
+ # write non-idle segment mask as interventions
+ non_idle_seg_mask = 1 - idle_seg_mask
+ if "interventions" in ep_grp:
+ del ep_grp["interventions"]
+ ep_grp.create_dataset("interventions", data=non_idle_seg_mask)
+
+ f.close()
+
+
+def combine_intervention_segments(hdf5_path, output_name, n=None):
+ """
+ Helper function to combine intervention segments in each demo trajectory together, and discard
+ non-intervention segments. This repurposes the postprocess_dataset_intervention_segments.py to
+ essentially remove the idle segments (which are non-intervention segments).
+ """
+ args = argparse.Namespace()
+ args.dataset = os.path.expandvars(os.path.expanduser(hdf5_path))
+ args.output_name = output_name
+ args.n = n
+ args.same = True
+ postprocess_dataset_intervention_segments(args)
+
+
+def remove_idle_segments(args):
+ if args.debug:
+ # print idle segments for the demos
+
+ # get demo keys
+ f = h5py.File(args.dataset, "r")
+ demos = list(f["data"].keys())
+ inds = np.argsort([int(elem[5:]) for elem in demos])
+ demos = [demos[i] for i in inds]
+ if args.n is not None:
+ demos = demos[:args.n]
+
+ for demo_key in demos:
+ idle_seg_mask = get_idle_segments_in_trajectory(
+ ep_grp=f["data/{}".format(demo_key)],
+ obs_pos_key="ee_pose",
+ # min_segment_length=1,
+ min_segment_length=7,
+ threshold=1e-4,
+ # threshold=3e-4,
+ # verbose=True,
+ verbose=False,
+ )
+ idle_segs = FileUtils.get_intervention_segments(idle_seg_mask)
+ print(demo_key)
+ # print(len(idle_segs))
+ print("idle segments")
+ print(idle_segs)
+ print("segment lengths")
+ print([seg[1] - seg[0] for seg in idle_segs])
+
+ f.close()
+ exit()
+
+ assert args.output_name is not None
+
+ # split each trajectory into idle and non-idle segments and write to "interventions" key
+ print("writing non-idle segments as interventions...")
+ write_non_idle_segments_as_interventions(
+ hdf5_path=args.dataset,
+ n=args.n,
+ # some good candidates below
+ min_segment_length=7,
+ threshold=1e-4,
+ # min_segment_length=7,
+ # threshold=3e-4,
+ )
+
+ # write new dataset, keeping only interventions
+ print("combining interventions into new dataset...")
+ combine_intervention_segments(
+ hdf5_path=args.dataset,
+ output_name=args.output_name,
+ n=args.n,
+ )
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ required=True,
+ help="path to input hdf5 dataset",
+ )
+ # name of hdf5 to write - it will be in the same directory as @dataset
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default=None,
+ help="name of output hdf5 dataset",
+ )
+
+ # specify number of demos to process - useful for debugging conversion with a handful
+ # of trajectories
+ parser.add_argument(
+ "--n",
+ type=int,
+ default=None,
+ help="(optional) stop after n trajectories are processed",
+ )
+
+ parser.add_argument(
+ "--debug",
+ action='store_true',
+ help="just print the idle and non-idle segment splits instead of actually doing any dataset processing",
+ )
+
+ args = parser.parse_args()
+ remove_idle_segments(args)
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/run_trained_agent.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/run_trained_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ef99cdc7742e0acbc4c7c9dabe51b7e8198187
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/run_trained_agent.py
@@ -0,0 +1,536 @@
+"""
+The main script for evaluating a policy in an environment.
+
+Args:
+ agent (str): path to saved checkpoint pth file
+
+ horizon (int): if provided, override maximum horizon of rollout from the one
+ in the checkpoint
+
+ env (str): if provided, override name of env from the one in the checkpoint,
+ and use it for rollouts
+
+ render (bool): if flag is provided, use on-screen rendering during rollouts
+
+ video_path (str): if provided, render trajectories to this video file path
+
+ video_skip (int): render frames to a video every @video_skip steps
+
+ camera_names (str or [str]): camera name(s) to use for rendering on-screen or to video
+
+ dataset_path (str): if provided, an hdf5 file will be written at this path with the
+ rollout data
+
+ dataset_obs (bool): if flag is provided, and @dataset_path is provided, include
+ possible high-dimensional observations in output dataset hdf5 file (by default,
+ observations are excluded and only simulator states are saved).
+
+ seed (int): if provided, set seed for rollouts
+
+Example usage:
+
+ # Evaluate a policy with 50 rollouts of maximum horizon 400 and save the rollouts to a video.
+ # Visualize the agentview and wrist cameras during the rollout.
+
+ python run_trained_agent.py --agent /path/to/model.pth \
+ --n_rollouts 50 --horizon 400 --seed 0 \
+ --video_path /path/to/output.mp4 \
+ --camera_names agentview robot0_eye_in_hand
+
+ # Write the 50 agent rollouts to a new dataset hdf5.
+
+ python run_trained_agent.py --agent /path/to/model.pth \
+ --n_rollouts 50 --horizon 400 --seed 0 \
+ --dataset_path /path/to/output.hdf5 --dataset_obs
+
+ # Write the 50 agent rollouts to a new dataset hdf5, but exclude the dataset observations
+ # since they might be high-dimensional (they can be extracted again using the
+ # dataset_states_to_obs.py script).
+
+ python run_trained_agent.py --agent /path/to/model.pth \
+ --n_rollouts 50 --horizon 400 --seed 0 \
+ --dataset_path /path/to/output.hdf5
+"""
+import argparse
+import os
+import json
+import h5py
+import imageio
+import sys
+import time
+import traceback
+import numpy as np
+from copy import deepcopy
+from tqdm import tqdm
+
+import torch
+
+import robomimic
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+from robomimic.utils.log_utils import log_warning
+from robomimic.envs.env_base import EnvBase
+from robomimic.envs.wrappers import EnvWrapper
+from robomimic.algo import RolloutPolicy
+from robomimic.scripts.playback_dataset import DEFAULT_CAMERAS
+
+
+def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, return_obs=False, camera_names=None, real=False, rate_measure=None):
+ """
+ Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video,
+ and returns the rollout trajectory.
+
+ Args:
+ policy (instance of RolloutPolicy): policy loaded from a checkpoint
+ env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
+ horizon (int): maximum horizon for the rollout
+ render (bool): whether to render rollout on-screen
+ video_writer (imageio writer): if provided, use to write rollout to video
+ video_skip (int): how often to write video frames
+ return_obs (bool): if True, return possibly high-dimensional observations along the trajectoryu.
+ They are excluded by default because the low-dimensional simulation states should be a minimal
+ representation of the environment.
+ camera_names (list): determines which camera(s) are used for rendering. Pass more than
+ one to output a video with multiple camera views concatenated horizontally.
+ real (bool): if real robot rollout
+ rate_measure: if provided, measure rate of action computation and do not play actions in environment
+
+ Returns:
+ stats (dict): some statistics for the rollout - such as return, horizon, and task success
+ traj (dict): dictionary that corresponds to the rollout trajectory
+ """
+ rollout_timestamp = time.time()
+ assert isinstance(env, EnvBase) or isinstance(env, EnvWrapper)
+ assert isinstance(policy, RolloutPolicy)
+ assert not (render and (video_writer is not None))
+
+ policy.start_episode()
+ obs = env.reset()
+ state_dict = dict()
+ if real:
+ input("ready for next eval? hit enter to continue")
+ else:
+ state_dict = env.get_state()
+ # hack that is necessary for robosuite tasks for deterministic action playback
+ obs = env.reset_to(state_dict)
+
+ results = {}
+ video_count = 0 # video frame counter
+ total_reward = 0.
+ got_exception = False
+ success = env.is_success()["task"]
+ traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict)
+ if return_obs:
+ # store observations too
+ traj.update(dict(obs=[], next_obs=[]))
+ try:
+ for step_i in range(horizon):
+ # HACK: some keys on real robot do not have a shape (and then they get frame stacked)
+ for k in obs:
+ if len(obs[k].shape) == 1:
+ obs[k] = obs[k][..., None]
+
+ # get action from policy
+ t1 = time.time()
+ act = policy(ob=obs)
+ t2 = time.time()
+ if real and (not env.base_env.controller_type == "JOINT_IMPEDANCE") and (policy.policy.global_config.algo_name != "diffusion_policy"):
+ # joint impedance actions and diffusion policy actions are absolute in the real world
+ act = np.clip(act, -1., 1.)
+
+ if rate_measure is not None:
+ rate_measure.measure()
+ print("time: {}s".format(t2 - t1))
+ # dummy reward and done
+ r = 0.
+ done = False
+ next_obs = obs
+ else:
+ # play action
+ next_obs, r, done, _ = env.step(act)
+
+ # compute reward
+ total_reward += r
+ success = env.is_success()["task"]
+
+ # visualization
+ if render:
+ env.render(mode="human", camera_name=camera_names[0])
+ if video_writer is not None:
+ if video_count % video_skip == 0:
+ video_img = []
+ for cam_name in camera_names:
+ video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
+ video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
+ video_writer.append_data(video_img)
+ video_count += 1
+
+ # collect transition
+ traj["actions"].append(act)
+ traj["rewards"].append(r)
+ traj["dones"].append(done)
+ if not real:
+ traj["states"].append(state_dict["states"])
+ if return_obs:
+ # Note: We need to "unprocess" the observations to prepare to write them to dataset.
+ # This includes operations like channel swapping and float to uint8 conversion
+ # for saving disk space.
+ traj["obs"].append(ObsUtils.unprocess_obs_dict(obs))
+ traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs))
+
+ # break if done or if success
+ if done or success:
+ break
+
+ # update for next iter
+ obs = deepcopy(next_obs)
+ if not real:
+ state_dict = env.get_state()
+
+ except env.rollout_exceptions as e:
+ print("WARNING: got rollout exception {}".format(e))
+ got_exception = True
+
+ stats = dict(
+ Return=total_reward,
+ Horizon=(step_i + 1),
+ Success_Rate=float(success),
+ Exception_Rate=float(got_exception),
+ time=(time.time() - rollout_timestamp),
+ )
+
+ if return_obs:
+ # convert list of dict to dict of list for obs dictionaries (for convenient writes to hdf5 dataset)
+ traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
+ traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])
+
+ # list to numpy array
+ for k in traj:
+ if k == "initial_state_dict":
+ continue
+ if isinstance(traj[k], dict):
+ for kp in traj[k]:
+ traj[k][kp] = np.array(traj[k][kp])
+ else:
+ traj[k] = np.array(traj[k])
+
+ return stats, traj
+
+
+def run_trained_agent(args):
+ # some arg checking
+ write_video = (args.video_path is not None)
+ assert not (args.render and write_video) # either on-screen or video but not both
+
+ rate_measure = None
+ if args.hz is not None:
+ import RobotTeleop
+ from RobotTeleop.utils import Rate, RateMeasure, Timers
+ rate_measure = RateMeasure(name="control_rate_measure", freq_threshold=args.hz)
+
+ # load ckpt dict and get algo name for sanity checks
+ algo_name, ckpt_dict = FileUtils.algo_name_from_checkpoint(ckpt_path=args.agent)
+
+ if args.dp_eval_steps is not None:
+ assert algo_name == "diffusion_policy"
+ log_warning("setting @num_inference_steps to {}".format(args.dp_eval_steps))
+
+ # HACK: modify the config, then dump to json again and write to ckpt_dict
+ tmp_config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
+ with tmp_config.values_unlocked():
+ if tmp_config.algo.ddpm.enabled:
+ tmp_config.algo.ddpm.num_inference_timesteps = args.dp_eval_steps
+ elif tmp_config.algo.ddim.enabled:
+ tmp_config.algo.ddim.num_inference_timesteps = args.dp_eval_steps
+ else:
+ raise Exception("should not reach here")
+ ckpt_dict['config'] = tmp_config.dump()
+
+ # device
+ device = TorchUtils.get_torch_device(try_to_use_cuda=True)
+
+ # restore policy
+ policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True)
+
+ # read rollout settings
+ rollout_num_episodes = args.n_rollouts
+ rollout_horizon = args.horizon
+ config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
+ if rollout_horizon is None:
+ # read horizon from config
+ rollout_horizon = config.experiment.rollout.horizon
+
+ # HACK: assume absolute actions for now if using diffusion policy on real robot
+ if (algo_name == "diffusion_policy") and EnvUtils.is_real_robot_gprs_env(env_meta=ckpt_dict["env_metadata"]):
+ ckpt_dict["env_metadata"]["env_kwargs"]["absolute_actions"] = True
+
+ # create environment from saved checkpoint
+ env, _ = FileUtils.env_from_checkpoint(
+ ckpt_dict=ckpt_dict,
+ env_name=args.env,
+ render=args.render,
+ render_offscreen=(args.video_path is not None),
+ verbose=True,
+ )
+
+ # Auto-fill camera rendering info if not specified
+ if args.camera_names is None:
+ # We fill in the automatic values
+ env_type = EnvUtils.get_env_type(env=env)
+ args.camera_names = DEFAULT_CAMERAS[env_type]
+ if args.render:
+ # on-screen rendering can only support one camera
+ assert len(args.camera_names) == 1
+
+ is_real_robot = EnvUtils.is_real_robot_env(env=env) or EnvUtils.is_real_robot_gprs_env(env=env)
+ if is_real_robot:
+ # on real robot - log some warnings
+ need_pause = False
+ if "env_name" not in ckpt_dict["env_metadata"]["env_kwargs"]:
+ log_warning("env_name not in checkpoint...proceed with caution...")
+ need_pause = True
+ if ckpt_dict["env_metadata"]["env_name"] != "EnvRealPandaGPRS":
+ # we will load EnvRealPandaGPRS class by default on real robot even if agent was collected with different class
+ log_warning("env name in metadata appears to be class ({}) different from EnvRealPandaGPRS".format(ckpt_dict["env_metadata"]["env_name"]))
+ need_pause = True
+ if need_pause:
+ ans = input("continue? (y/n)")
+ if ans != "y":
+ exit()
+
+ # maybe set seed
+ if args.seed is not None:
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+
+ # maybe create video writer
+ video_writer = None
+ if write_video:
+ video_writer = imageio.get_writer(args.video_path, fps=20)
+
+ # maybe open hdf5 to write rollouts
+ write_dataset = (args.dataset_path is not None)
+ if write_dataset:
+ data_writer = h5py.File(args.dataset_path, "w")
+ data_grp = data_writer.create_group("data")
+ total_samples = 0
+
+ rollout_stats = []
+ for i in tqdm(range(rollout_num_episodes)):
+ try:
+ stats, traj = rollout(
+ policy=policy,
+ env=env,
+ horizon=rollout_horizon,
+ render=args.render,
+ video_writer=video_writer,
+ video_skip=args.video_skip,
+ return_obs=(write_dataset and args.dataset_obs),
+ camera_names=args.camera_names,
+ real=is_real_robot,
+ rate_measure=rate_measure,
+ )
+ except KeyboardInterrupt:
+ if is_real_robot:
+ print("ctrl-C catched, stop execution")
+ print("env rate measure")
+ print(env.rate_measure)
+ ans = input("success? (y / n)")
+ rollout_stats.append((1 if ans == "y" else 0))
+ print("*" * 50)
+ print("have {} success out of {} attempts".format(np.sum(rollout_stats), len(rollout_stats)))
+ print("*" * 50)
+ continue
+ else:
+ sys.exit(0)
+
+ if is_real_robot:
+ print("TERMINATE WITHOUT KEYBOARD INTERRUPT...")
+ ans = input("success? (y / n)")
+ rollout_stats.append((1 if ans == "y" else 0))
+ continue
+ rollout_stats.append(stats)
+
+ if write_dataset:
+ # store transitions
+ ep_data_grp = data_grp.create_group("demo_{}".format(i))
+ ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
+ ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
+ ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
+ ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
+ if args.dataset_obs:
+ for k in traj["obs"]:
+ ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
+ ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]))
+
+ # episode metadata
+ if "model" in traj["initial_state_dict"]:
+ ep_data_grp.attrs["model_file"] = traj["initial_state_dict"]["model"] # model xml for this episode
+ ep_data_grp.attrs["num_samples"] = traj["actions"].shape[0] # number of transitions in this episode
+ total_samples += traj["actions"].shape[0]
+
+ rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
+ avg_rollout_stats = { k : np.mean(rollout_stats[k]) for k in rollout_stats }
+ avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
+ avg_rollout_stats["Time_Episode"] = np.sum(rollout_stats["time"]) / 60. # total time taken for rollouts in minutes
+ avg_rollout_stats["Num_Episode"] = len(rollout_stats["Success_Rate"]) # number of episodes attempted
+ print("Average Rollout Stats")
+ stats_json = json.dumps(avg_rollout_stats, indent=4)
+ print(stats_json)
+ if args.json_path is not None:
+ json_f = open(args.json_path, "w")
+ json_f.write(stats_json)
+ json_f.close()
+
+ if write_video:
+ video_writer.close()
+
+ if write_dataset:
+ # global metadata
+ data_grp.attrs["total"] = total_samples
+ data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4) # environment info
+ data_writer.close()
+ print("Wrote dataset trajectories to {}".format(args.dataset_path))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Path to trained model
+ parser.add_argument(
+ "--agent",
+ type=str,
+ required=True,
+ help="path to saved checkpoint pth file",
+ )
+
+ # number of rollouts
+ parser.add_argument(
+ "--n_rollouts",
+ type=int,
+ default=27,
+ help="number of rollouts",
+ )
+
+ # maximum horizon of rollout, to override the one stored in the model checkpoint
+ parser.add_argument(
+ "--horizon",
+ type=int,
+ default=None,
+ help="(optional) override maximum horizon of rollout from the one in the checkpoint",
+ )
+
+ # Env Name (to override the one stored in model checkpoint)
+ parser.add_argument(
+ "--env",
+ type=str,
+ default=None,
+ help="(optional) override name of env from the one in the checkpoint, and use\
+ it for rollouts",
+ )
+
+ # Whether to render rollouts to screen
+ parser.add_argument(
+ "--render",
+ action='store_true',
+ help="on-screen rendering",
+ )
+
+ # Dump a video of the rollouts to the specified path
+ parser.add_argument(
+ "--video_path",
+ type=str,
+ default=None,
+ help="(optional) render rollouts to this video file path",
+ )
+
+ # How often to write video frames during the rollout
+ parser.add_argument(
+ "--video_skip",
+ type=int,
+ default=5,
+ help="render frames to video every n steps",
+ )
+
+ # camera names to render
+ parser.add_argument(
+ "--camera_names",
+ type=str,
+ nargs='+',
+ default=None,
+ help="(optional) camera name(s) to use for rendering on-screen or to video",
+ )
+
+ # If provided, an hdf5 file will be written with the rollout data
+ parser.add_argument(
+ "--dataset_path",
+ type=str,
+ default=None,
+ help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
+ )
+
+ # If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
+ parser.add_argument(
+ "--dataset_obs",
+ action='store_true',
+ help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
+ observations are excluded and only simulator states are saved)",
+ )
+
+ # for seeding before starting rollouts
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="(optional) set seed for rollouts",
+ )
+
+ # Dump a json of the rollout results stats to the specified path
+ parser.add_argument(
+ "--json_path",
+ type=str,
+ default=None,
+ help="(optional) dump a json of the rollout results stats to the specified path",
+ )
+
+ # Dump a file with the error traceback at this path. Only created if run fails with an error.
+ parser.add_argument(
+ "--error_path",
+ type=str,
+ default=None,
+ help="(optional) dump a file with the error traceback at this path. Only created if run fails with an error.",
+ )
+
+ # TODO: clean up this arg
+ # If provided, do not run actions in env, and instead just measure the rate of action computation
+ parser.add_argument(
+ "--hz",
+ type=int,
+ default=None,
+ help="If provided, do not run actions in env, and instead just measure the rate of action computation and raise warnings if it dips below this threshold",
+ )
+
+ # TODO: clean up this arg
+ # If provided, set num_inference_timesteps explicitly for diffusion policy evaluation
+ parser.add_argument(
+ "--dp_eval_steps",
+ type=int,
+ default=None,
+ help="If provided, set num_inference_timesteps explicitly for diffusion policy evaluation",
+ )
+
+ args = parser.parse_args()
+ res_str = None
+ try:
+ run_trained_agent(args)
+ except Exception as e:
+ res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
+ if args.error_path is not None:
+ # write traceback to file
+ f = open(args.error_path, "w")
+ f.write(res_str)
+ f.close()
+ raise e
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/setup_macros.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/setup_macros.py
new file mode 100644
index 0000000000000000000000000000000000000000..92c472712078684a84e9ae624b13cd7d9b6c953c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/setup_macros.py
@@ -0,0 +1,32 @@
+"""
+This script sets up a private macros file.
+
+The private macros file (macros_private.py) is not tracked by git,
+allowing user-specific settings that are not tracked by git.
+
+This script checks if macros_private.py exists.
+If applicable, it creates the private macros at robomimic/macros_private.py
+"""
+
+import os
+import robomimic
+import shutil
+
+if __name__ == "__main__":
+ base_path = robomimic.__path__[0]
+ macros_path = os.path.join(base_path, "macros.py")
+ macros_private_path = os.path.join(base_path, "macros_private.py")
+
+ if not os.path.exists(macros_path):
+ print("{} does not exist! Aborting...".format(macros_path))
+
+ if os.path.exists(macros_private_path):
+ ans = input("{} already exists! \noverwrite? (y/n)\n".format(macros_private_path))
+
+ if ans == "y":
+ print("REMOVING")
+ else:
+ exit()
+
+ shutil.copyfile(macros_path, macros_private_path)
+ print("copied {}\nto {}".format(macros_path, macros_private_path))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/split_train_val.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/split_train_val.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0502ea81dc238e21e0211c1c71c803f0b1b00d
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/split_train_val.py
@@ -0,0 +1,105 @@
+"""
+Script for splitting a dataset hdf5 file into training and validation trajectories.
+
+Args:
+ dataset (str): path to hdf5 dataset
+
+ filter_key (str): if provided, split the subset of trajectories
+ in the file that correspond to this filter key into a training
+ and validation set of trajectories, instead of splitting the
+ full set of trajectories
+
+ ratio (float): validation ratio, in (0, 1). Defaults to 0.1, which is 10%.
+
+Example usage:
+ python split_train_val.py --dataset /path/to/demo.hdf5 --ratio 0.1
+"""
+
+import argparse
+import h5py
+import numpy as np
+
+from robomimic.utils.file_utils import create_hdf5_filter_key
+
+
+def split_train_val_from_hdf5(hdf5_path, val_ratio=0.1, filter_key=None):
+ """
+ Splits data into training set and validation set from HDF5 file.
+
+ Args:
+ hdf5_path (str): path to the hdf5 file
+ to load the transitions from
+
+ val_ratio (float): ratio of validation demonstrations to all demonstrations
+
+ filter_key (str): if provided, split the subset of demonstration keys stored
+ under mask/@filter_key instead of the full set of demonstrations
+ """
+
+ # retrieve number of demos
+ f = h5py.File(hdf5_path, "r")
+ if filter_key is not None:
+ print("using filter key: {}".format(filter_key))
+ demos = sorted([elem.decode("utf-8") for elem in np.array(f["mask/{}".format(filter_key)])])
+ else:
+ demos = sorted(list(f["data"].keys()))
+ num_demos = len(demos)
+ f.close()
+
+ # get random split
+ num_demos = len(demos)
+ num_val = int(val_ratio * num_demos)
+ mask = np.zeros(num_demos)
+ mask[:num_val] = 1.
+ np.random.shuffle(mask)
+ mask = mask.astype(int)
+ train_inds = (1 - mask).nonzero()[0]
+ valid_inds = mask.nonzero()[0]
+ train_keys = [demos[i] for i in train_inds]
+ valid_keys = [demos[i] for i in valid_inds]
+ print("{} validation demonstrations out of {} total demonstrations.".format(num_val, num_demos))
+
+ # pass mask to generate split
+ name_1 = "train"
+ name_2 = "valid"
+ if filter_key is not None:
+ name_1 = "{}_{}".format(filter_key, name_1)
+ name_2 = "{}_{}".format(filter_key, name_2)
+
+ train_lengths = create_hdf5_filter_key(hdf5_path=hdf5_path, demo_keys=train_keys, key_name=name_1)
+ valid_lengths = create_hdf5_filter_key(hdf5_path=hdf5_path, demo_keys=valid_keys, key_name=name_2)
+
+ print("Total number of train samples: {}".format(np.sum(train_lengths)))
+ print("Average number of train samples {}".format(np.mean(train_lengths)))
+
+ print("Total number of valid samples: {}".format(np.sum(valid_lengths)))
+ print("Average number of valid samples {}".format(np.mean(valid_lengths)))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="path to hdf5 dataset",
+ )
+ parser.add_argument(
+ "--filter_key",
+ type=str,
+ default=None,
+ help="if provided, split the subset of trajectories in the file that correspond to\
+ this filter key into a training and validation set of trajectories, instead of\
+ splitting the full set of trajectories",
+ )
+ parser.add_argument(
+ "--ratio",
+ type=float,
+ default=0.1,
+ help="validation ratio, in (0, 1)"
+ )
+ args = parser.parse_args()
+
+ # seed to make sure results are consistent
+ np.random.seed(0)
+
+ split_train_val_from_hdf5(args.dataset, val_ratio=args.ratio, filter_key=args.filter_key)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/scripts/train.py b/phantom/submodules/phantom-robomimic/robomimic/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b101984e2972395175e1b0c21563b9ab15cba2d
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/scripts/train.py
@@ -0,0 +1,599 @@
+"""
+The main entry point for training policies.
+
+Args:
+ config (str): path to a config json that will be used to override the default settings.
+ If omitted, default settings are used. This is the preferred way to run experiments.
+
+ algo (str): name of the algorithm to run. Only needs to be provided if @config is not
+ provided.
+
+ name (str): if provided, override the experiment name defined in the config
+
+ dataset (str): if provided, override the dataset path defined in the config
+
+ debug (bool): set this flag to run a quick training run for debugging purposes
+"""
+
+import argparse
+import json
+import numpy as np
+import time
+import os
+import shutil
+import psutil
+import sys
+import socket
+import traceback
+
+from collections import OrderedDict
+
+import torch
+from torch.utils.data import DataLoader
+
+import robomimic
+import robomimic.macros as Macros
+import robomimic.utils.train_utils as TrainUtils
+import robomimic.utils.torch_utils as TorchUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.file_utils as FileUtils
+from robomimic.config import config_factory
+from robomimic.algo import algo_factory, RolloutPolicy
+from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
+
+
+def train(config, device, auto_remove_exp=False):
+ """
+ Train a model using the algorithm.
+ """
+
+ # time this run
+ start_time = time.time()
+
+ # first set seeds
+ np.random.seed(config.train.seed)
+ torch.manual_seed(config.train.seed)
+
+ torch.set_num_threads(2)
+
+ print("\n============= New Training Run with Config =============")
+ print(config)
+ print("")
+ log_dir, ckpt_dir, video_dir = TrainUtils.get_exp_dir(config, auto_remove_exp_dir=auto_remove_exp)
+
+ if config.experiment.logging.terminal_output_to_txt:
+ # log stdout and stderr to a text file
+ logger = PrintLogger(os.path.join(log_dir, 'log.txt'))
+ sys.stdout = logger
+ sys.stderr = logger
+
+ # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
+ ObsUtils.initialize_obs_utils_with_config(config)
+
+ # make sure the dataset exists
+ if isinstance(config.train.data, str):
+ dataset_path = os.path.expandvars(os.path.expanduser(config.train.data))
+ else:
+ eval_dataset_cfg = config.train.data[0]
+ dataset_path = os.path.expandvars(os.path.expanduser(eval_dataset_cfg["path"]))
+ ds_format = config.train.data_format
+ if not os.path.exists(dataset_path):
+ raise Exception("Dataset at provided path {} not found!".format(dataset_path))
+
+ # load basic metadata from training file
+ print("\n============= Loaded Environment Metadata =============")
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path, ds_format=ds_format)
+
+ # update env meta if applicable
+ from robomimic.utils.script_utils import deep_update
+ deep_update(env_meta, config.experiment.env_meta_update_dict)
+
+ shape_meta = FileUtils.get_shape_metadata_from_dataset(
+ dataset_path=dataset_path,
+ action_keys=config.train.action_keys,
+ all_obs_keys=config.all_obs_keys,
+ ds_format=ds_format,
+ verbose=True
+ )
+
+ if config.experiment.env is not None:
+ env_meta["env_name"] = config.experiment.env
+ print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)
+
+ # create environment
+ envs = OrderedDict()
+ if config.experiment.rollout.enabled:
+ # create environments for validation runs
+ env_names = [env_meta["env_name"]]
+
+ if config.experiment.additional_envs is not None:
+ for name in config.experiment.additional_envs:
+ env_names.append(name)
+
+ for env_name in env_names:
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ env_name=env_name,
+ render=config.experiment.render,
+ render_offscreen=config.experiment.render_video,
+ use_image_obs=shape_meta["use_images"],
+ use_depth_obs=shape_meta["use_depths"],
+ )
+ env = EnvUtils.wrap_env_from_config(env, config=config) # apply environment warpper, if applicable
+ envs[env.name] = env
+ print(envs[env.name])
+
+ print("")
+
+ # setup for a new training run
+ data_logger = DataLogger(
+ log_dir,
+ config,
+ log_tb=config.experiment.logging.log_tb,
+ log_wandb=config.experiment.logging.log_wandb,
+ )
+ model = algo_factory(
+ algo_name=config.algo_name,
+ config=config,
+ obs_key_shapes=shape_meta["all_shapes"],
+ ac_dim=shape_meta["ac_dim"],
+ device=device,
+ )
+
+ # save the config as a json file
+ with open(os.path.join(log_dir, '..', 'config.json'), 'w') as outfile:
+ json.dump(config, outfile, indent=4)
+
+ print("\n============= Model Summary =============")
+ print(model) # print model summary
+ print("")
+
+ # load training data
+ trainset, validset = TrainUtils.load_data_for_training(
+ config, obs_keys=shape_meta["all_obs_keys"])
+ train_sampler = trainset.get_dataset_sampler()
+ print("\n============= Training Dataset =============")
+ print(trainset)
+ print("")
+ if validset is not None:
+ print("\n============= Validation Dataset =============")
+ print(validset)
+ print("")
+
+ # maybe retreve statistics for normalizing observations
+ obs_normalization_stats = None
+ if config.train.hdf5_normalize_obs:
+ obs_normalization_stats = trainset.get_obs_normalization_stats()
+
+ # maybe retreve statistics for normalizing actions
+ action_normalization_stats = trainset.get_action_normalization_stats()
+
+ # initialize data loaders
+ train_loader = DataLoader(
+ dataset=trainset,
+ sampler=train_sampler,
+ batch_size=config.train.batch_size,
+ shuffle=(train_sampler is None),
+ num_workers=config.train.num_data_workers,
+ drop_last=True
+ )
+
+ if config.experiment.validate:
+ # cap num workers for validation dataset at 1
+ num_workers = min(config.train.num_data_workers, 1)
+ valid_sampler = validset.get_dataset_sampler()
+ valid_loader = DataLoader(
+ dataset=validset,
+ sampler=valid_sampler,
+ batch_size=config.train.batch_size,
+ shuffle=(valid_sampler is None),
+ num_workers=num_workers,
+ drop_last=True
+ )
+ else:
+ valid_loader = None
+
+ # print all warnings before training begins
+ print("*" * 50)
+ print("Warnings generated by robomimic have been duplicated here (from above) for convenience. Please check them carefully.")
+ flush_warnings()
+ print("*" * 50)
+ print("")
+
+ # main training loop
+ best_valid_loss = None
+ best_return = {k: -np.inf for k in envs} if config.experiment.rollout.enabled else None
+ best_success_rate = {k: -1. for k in envs} if config.experiment.rollout.enabled else None
+ last_ckpt_time = time.time()
+
+ need_sync_results = (Macros.RESULTS_SYNC_PATH_ABS is not None)
+ if need_sync_results:
+ # these paths will be updated after each evaluation
+ best_ckpt_path_synced = None
+ best_video_path_synced = None
+ last_ckpt_path_synced = None
+ last_video_path_synced = None
+ log_dir_path_synced = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "logs")
+
+ # number of learning steps per epoch (defaults to a full dataset pass)
+ train_num_steps = config.experiment.epoch_every_n_steps
+ valid_num_steps = config.experiment.validation_epoch_every_n_steps
+
+ for epoch in range(1, config.train.num_epochs + 1): # epoch numbers start at 1
+ step_log = TrainUtils.run_epoch(
+ model=model,
+ data_loader=train_loader,
+ epoch=epoch,
+ num_steps=train_num_steps,
+ obs_normalization_stats=obs_normalization_stats,
+ )
+ model.on_epoch_end(epoch)
+
+ # setup checkpoint path
+ epoch_ckpt_name = "model_epoch_{}".format(epoch)
+
+ # check for recurring checkpoint saving conditions
+ should_save_ckpt = False
+ if config.experiment.save.enabled:
+ time_check = (config.experiment.save.every_n_seconds is not None) and \
+ (time.time() - last_ckpt_time > config.experiment.save.every_n_seconds)
+ epoch_check = (config.experiment.save.every_n_epochs is not None) and \
+ (epoch > 0) and (epoch % config.experiment.save.every_n_epochs == 0)
+ epoch_list_check = (epoch in config.experiment.save.epochs)
+ should_save_ckpt = (time_check or epoch_check or epoch_list_check)
+ ckpt_reason = None
+ if should_save_ckpt:
+ last_ckpt_time = time.time()
+ ckpt_reason = "time"
+
+ print("Train Epoch {}".format(epoch))
+ print(json.dumps(step_log, sort_keys=True, indent=4))
+ for k, v in step_log.items():
+ if k.startswith("Time_"):
+ data_logger.record("Timing_Stats/Train_{}".format(k[5:]), v, epoch)
+ else:
+ data_logger.record("Train/{}".format(k), v, epoch)
+
+ # Evaluate the model on validation set
+ if config.experiment.validate:
+ with torch.no_grad():
+ step_log = TrainUtils.run_epoch(model=model, data_loader=valid_loader, epoch=epoch, validate=True, num_steps=valid_num_steps)
+ for k, v in step_log.items():
+ if k.startswith("Time_"):
+ data_logger.record("Timing_Stats/Valid_{}".format(k[5:]), v, epoch)
+ else:
+ data_logger.record("Valid/{}".format(k), v, epoch)
+
+ print("Validation Epoch {}".format(epoch))
+ print(json.dumps(step_log, sort_keys=True, indent=4))
+
+ # save checkpoint if achieve new best validation loss
+ valid_check = "Loss" in step_log
+ if valid_check and (best_valid_loss is None or (step_log["Loss"] <= best_valid_loss)):
+ best_valid_loss = step_log["Loss"]
+ if config.experiment.save.enabled and config.experiment.save.on_best_validation:
+ epoch_ckpt_name += "_best_validation_{}".format(best_valid_loss)
+ should_save_ckpt = True
+ ckpt_reason = "valid" if ckpt_reason is None else ckpt_reason
+
+ # Evaluate the model by by running rollouts
+
+ # do rollouts at fixed rate or if it's time to save a new ckpt
+ video_paths = None
+ rollout_check = (epoch % config.experiment.rollout.rate == 0) or (should_save_ckpt and ckpt_reason == "time")
+ did_rollouts = False
+ if config.experiment.rollout.enabled and (epoch > config.experiment.rollout.warmstart) and rollout_check:
+
+ # wrap model as a RolloutPolicy to prepare for rollouts
+ rollout_model = RolloutPolicy(
+ model,
+ obs_normalization_stats=obs_normalization_stats,
+ action_normalization_stats=action_normalization_stats,
+ )
+
+ num_episodes = config.experiment.rollout.n
+ all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
+ policy=rollout_model,
+ envs=envs,
+ horizon=config.experiment.rollout.horizon,
+ use_goals=config.use_goals,
+ num_episodes=num_episodes,
+ render=False,
+ video_dir=video_dir if config.experiment.render_video else None,
+ epoch=epoch,
+ video_skip=config.experiment.get("video_skip", 5),
+ terminate_on_success=config.experiment.rollout.terminate_on_success,
+ )
+
+ # summarize results from rollouts to tensorboard and terminal
+ for env_name in all_rollout_logs:
+ rollout_logs = all_rollout_logs[env_name]
+ for k, v in rollout_logs.items():
+ if k.startswith("Time_"):
+ data_logger.record("Timing_Stats/Rollout_{}_{}".format(env_name, k[5:]), v, epoch)
+ else:
+ data_logger.record("Rollout/{}/{}".format(k, env_name), v, epoch, log_stats=True)
+
+ print("\nEpoch {} Rollouts took {}s (avg) with results:".format(epoch, rollout_logs["time"]))
+ print('Env: {}'.format(env_name))
+ print(json.dumps(rollout_logs, sort_keys=True, indent=4))
+
+ # checkpoint and video saving logic
+ updated_stats = TrainUtils.should_save_from_rollout_logs(
+ all_rollout_logs=all_rollout_logs,
+ best_return=best_return,
+ best_success_rate=best_success_rate,
+ epoch_ckpt_name=epoch_ckpt_name,
+ save_on_best_rollout_return=config.experiment.save.on_best_rollout_return,
+ save_on_best_rollout_success_rate=config.experiment.save.on_best_rollout_success_rate,
+ )
+ best_return = updated_stats["best_return"]
+ best_success_rate = updated_stats["best_success_rate"]
+ epoch_ckpt_name = updated_stats["epoch_ckpt_name"]
+ should_save_ckpt = (config.experiment.save.enabled and updated_stats["should_save_ckpt"]) or should_save_ckpt
+ if updated_stats["ckpt_reason"] is not None:
+ ckpt_reason = updated_stats["ckpt_reason"]
+ did_rollouts = True
+
+ # Only keep saved videos if the ckpt should be saved (but not because of validation score)
+ should_save_video = (should_save_ckpt and (ckpt_reason != "valid")) or config.experiment.keep_all_videos
+ if video_paths is not None and not should_save_video:
+ for env_name in video_paths:
+ os.remove(video_paths[env_name])
+
+ # Save model checkpoints based on conditions (success rate, validation loss, etc)
+ if should_save_ckpt:
+ TrainUtils.save_model(
+ model=model,
+ config=config,
+ env_meta=env_meta,
+ shape_meta=shape_meta,
+ ckpt_path=os.path.join(ckpt_dir, epoch_ckpt_name + ".pth"),
+ obs_normalization_stats=obs_normalization_stats,
+ action_normalization_stats=action_normalization_stats,
+ )
+
+ # maybe sync some results back to scratch space (only if rollouts happened)
+ if did_rollouts and need_sync_results:
+ print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
+
+ # get best and latest model checkpoints and videos
+ best_ckpt_path_to_sync, best_video_path_to_sync, best_epoch_to_sync = TrainUtils.get_model_from_output_folder(
+ models_path=ckpt_dir,
+ videos_path=video_dir if config.experiment.render_video else None,
+ best=True,
+ )
+ last_ckpt_path_to_sync, last_video_path_to_sync, last_epoch_to_sync = TrainUtils.get_model_from_output_folder(
+ models_path=ckpt_dir,
+ videos_path=video_dir if config.experiment.render_video else None,
+ last=True,
+ )
+
+ # clear last files that we synced over
+ if best_ckpt_path_synced is not None:
+ os.remove(best_ckpt_path_synced)
+ if last_ckpt_path_synced is not None:
+ os.remove(last_ckpt_path_synced)
+ if best_video_path_synced is not None:
+ os.remove(best_video_path_synced)
+ if last_video_path_synced is not None:
+ os.remove(last_video_path_synced)
+ if os.path.exists(log_dir_path_synced):
+ shutil.rmtree(log_dir_path_synced)
+
+ # set write paths and sync new files over
+ best_success_rate_for_sync = float(best_ckpt_path_to_sync.split("success_")[-1][:-4])
+ best_ckpt_path_synced = os.path.join(
+ Macros.RESULTS_SYNC_PATH_ABS,
+ os.path.basename(best_ckpt_path_to_sync)[:-4] + "_best.pth",
+ )
+ shutil.copyfile(best_ckpt_path_to_sync, best_ckpt_path_synced)
+ last_ckpt_path_synced = os.path.join(
+ Macros.RESULTS_SYNC_PATH_ABS,
+ os.path.basename(last_ckpt_path_to_sync)[:-4] + "_last.pth",
+ )
+ shutil.copyfile(last_ckpt_path_to_sync, last_ckpt_path_synced)
+ if config.experiment.render_video:
+ best_video_path_synced = os.path.join(
+ Macros.RESULTS_SYNC_PATH_ABS,
+ os.path.basename(best_video_path_to_sync)[:-4] + "_best_{}.mp4".format(best_success_rate_for_sync),
+ )
+ shutil.copyfile(best_video_path_to_sync, best_video_path_synced)
+ last_video_path_synced = os.path.join(
+ Macros.RESULTS_SYNC_PATH_ABS,
+ os.path.basename(last_video_path_to_sync)[:-4] + "_last.mp4",
+ )
+ shutil.copyfile(last_video_path_to_sync, last_video_path_synced)
+ # sync logs dir
+ shutil.copytree(log_dir, log_dir_path_synced)
+ # sync config json
+ shutil.copyfile(
+ os.path.join(log_dir, '..', 'config.json'),
+ os.path.join(Macros.RESULTS_SYNC_PATH_ABS, 'config.json')
+ )
+
+ # Finally, log memory usage in MB
+ process = psutil.Process(os.getpid())
+ mem_usage = int(process.memory_info().rss / 1000000)
+ data_logger.record("System/RAM Usage (MB)", mem_usage, epoch)
+ print("\nEpoch {} Memory Usage: {} MB\n".format(epoch, mem_usage))
+
+ # terminate logging
+ data_logger.close()
+
+ # sync logs after closing data logger to make sure everything was transferred
+ if need_sync_results:
+ print("Sync results back to sync path: {}".format(Macros.RESULTS_SYNC_PATH_ABS))
+ # sync logs dir
+ if os.path.exists(log_dir_path_synced):
+ shutil.rmtree(log_dir_path_synced)
+ shutil.copytree(log_dir, log_dir_path_synced)
+
+ # collect important statistics
+ important_stats = dict()
+ prefix = "Rollout/Success_Rate/"
+ exception_prefix = "Rollout/Exception_Rate/"
+ for k in data_logger._data:
+ if k.startswith(prefix):
+ suffix = k[len(prefix):]
+ stats = data_logger.get_stats(k)
+ important_stats["{}-max".format(suffix)] = stats["max"]
+ important_stats["{}-mean".format(suffix)] = stats["mean"]
+ elif k.startswith(exception_prefix):
+ suffix = k[len(exception_prefix):]
+ stats = data_logger.get_stats(k)
+ important_stats["{}-exception-rate-max".format(suffix)] = stats["max"]
+ important_stats["{}-exception-rate-mean".format(suffix)] = stats["mean"]
+
+ # add in time taken
+ important_stats["time spent (hrs)"] = "{:.2f}".format((time.time() - start_time) / 3600.)
+
+ # write stats to disk
+ json_file_path = os.path.join(log_dir, "important_stats.json")
+ with open(json_file_path, 'w') as f:
+ # preserve original key ordering
+ json.dump(important_stats, f, sort_keys=False, indent=4)
+
+ return important_stats
+
+
+def main(args):
+
+ if args.config is not None:
+ ext_cfg = json.load(open(args.config, 'r'))
+ config = config_factory(ext_cfg["algo_name"])
+ # update config with external json - this will throw errors if
+ # the external config has keys not present in the base algo config
+ with config.values_unlocked():
+ config.update(ext_cfg)
+ else:
+ config = config_factory(args.algo)
+
+ if args.dataset is not None:
+ config.train.data = [dict(path=args.dataset)]
+
+ if args.name is not None:
+ config.experiment.name = args.name
+
+ if args.output is not None:
+ config.train.output_dir = args.output
+
+ # get torch device
+ device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
+
+ # maybe modify config for debugging purposes
+ if args.debug:
+ Macros.DEBUG = True
+
+ # shrink length of training to test whether this run is likely to crash
+ config.unlock()
+ config.lock_keys()
+
+ # train and validate (if enabled) for 3 gradient steps, for 2 epochs
+ config.experiment.epoch_every_n_steps = 3
+ config.experiment.validation_epoch_every_n_steps = 3
+ config.train.num_epochs = 2
+
+ # if rollouts are enabled, try 2 rollouts at end of each epoch, with 10 environment steps
+ config.experiment.rollout.rate = 1
+ config.experiment.rollout.n = 2
+ config.experiment.rollout.horizon = 10
+
+ # send output to a temporary directory
+ config.train.output_dir = "/tmp/tmp_trained_models"
+
+ # lock config to prevent further modifications and ensure missing keys raise errors
+ config.lock()
+
+ # catch error during training and print it
+ res_str = "finished run successfully!"
+ important_stats = None
+ try:
+ important_stats = train(config, device=device, auto_remove_exp=args.auto_remove_exp)
+ except Exception as e:
+ res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
+ print(res_str)
+ if important_stats is not None:
+ important_stats = json.dumps(important_stats, indent=4)
+ print("\nRollout Success Rate Stats")
+ print(important_stats)
+
+ # maybe sync important stats back
+ if Macros.RESULTS_SYNC_PATH_ABS is not None:
+ json_file_path = os.path.join(Macros.RESULTS_SYNC_PATH_ABS, "important_stats.json")
+ with open(json_file_path, 'w') as f:
+ # preserve original key ordering
+ json.dump(important_stats, f, sort_keys=False, indent=4)
+
+ # maybe give slack notification
+ if Macros.SLACK_TOKEN is not None:
+ from robomimic.scripts.give_slack_notification import give_slack_notif
+ msg = "Completed the following training run!\nHostname: {}\nExperiment Name: {}\n".format(socket.gethostname(), config.experiment.name)
+ msg += "```{}```".format(res_str)
+ if important_stats is not None:
+ msg += "\nRollout Success Rate Stats"
+ msg += "\n```{}```".format(important_stats)
+ give_slack_notif(msg)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # External config file that overwrites default config
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=None,
+ help="(optional) path to a config json that will be used to override the default settings. \
+ If omitted, default settings are used. This is the preferred way to run experiments.",
+ )
+
+ # Algorithm Name
+ parser.add_argument(
+ "--algo",
+ type=str,
+ help="(optional) name of algorithm to run. Only needs to be provided if --config is not provided",
+ )
+
+ # Experiment Name (for tensorboard, saving models, etc.)
+ parser.add_argument(
+ "--name",
+ type=str,
+ default=None,
+ help="(optional) if provided, override the experiment name defined in the config",
+ )
+
+ # Dataset path, to override the one in the config
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default=None,
+ help="(optional) if provided, override the dataset path defined in the config",
+ )
+
+ # Output path, to override the one in the config
+ parser.add_argument(
+ "--output",
+ type=str,
+ default=None,
+ help="(optional) if provided, override the output folder path defined in the config",
+ )
+
+ # force delete the experiment folder if it exists
+ parser.add_argument(
+ "--auto-remove-exp",
+ action='store_true',
+ help="force delete the experiment folder if it exists"
+ )
+
+ # debug mode
+ parser.add_argument(
+ "--debug",
+ action='store_true',
+ help="set this flag to run a quick training run for debugging purposes"
+ )
+
+ args = parser.parse_args()
+ main(args)
+
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/__init__.py b/phantom/submodules/phantom-robomimic/robomimic/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/action_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/action_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac974d50c5fab46b85cd8c3bb76d8e05a4f56aba
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/action_utils.py
@@ -0,0 +1,35 @@
+
+from typing import Union, Sequence, Dict, Optional, Tuple
+
+from copy import deepcopy
+from collections import OrderedDict
+import functools
+
+import numpy as np
+
+
+def action_dict_to_vector(
+ action_dict: Dict[str, np.ndarray],
+ action_keys: Optional[Sequence[str]]=None) -> np.ndarray:
+ if action_keys is None:
+ action_keys = list(action_dict.keys())
+ actions = [action_dict[k] for k in action_keys]
+
+ action_vec = np.concatenate(actions, axis=-1)
+ return action_vec
+
+
+def vector_to_action_dict(
+ action: np.ndarray,
+ action_shapes: Dict[str, Tuple[int]],
+ action_keys: Sequence[str]) -> Dict[str, np.ndarray]:
+ action_dict = dict()
+ start_idx = 0
+ for key in action_keys:
+ this_act_shape = action_shapes[key]
+ this_act_dim = np.prod(this_act_shape)
+ end_idx = start_idx + this_act_dim
+ action_dict[key] = action[...,start_idx:end_idx].reshape(
+ action.shape[:-1]+this_act_shape)
+ start_idx = end_idx
+ return action_dict
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/dataset.py b/phantom/submodules/phantom-robomimic/robomimic/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d429c7a46d09767f8b1946765e599bbc5667da3
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/dataset.py
@@ -0,0 +1,1134 @@
+"""
+This file contains Dataset classes that are used by torch dataloaders
+to fetch batches from hdf5 files.
+"""
+import os
+import h5py
+import numpy as np
+from copy import deepcopy
+from contextlib import contextmanager
+from collections import OrderedDict
+
+import torch.utils.data
+
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.action_utils as AcUtils
+import robomimic.utils.log_utils as LogUtils
+
+
+class SequenceDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ hdf5_path,
+ obs_keys,
+ action_keys,
+ dataset_keys,
+ action_config,
+ frame_stack=1,
+ seq_length=1,
+ pad_frame_stack=True,
+ pad_seq_length=True,
+ get_pad_mask=False,
+ goal_mode=None,
+ hdf5_cache_mode=None,
+ hdf5_use_swmr=True,
+ hdf5_normalize_obs=False,
+ filter_by_attribute=None,
+ load_next_obs=True,
+ ):
+ """
+ Dataset class for fetching sequences of experience.
+ Length of the fetched sequence is equal to (@frame_stack - 1 + @seq_length)
+
+ Args:
+ hdf5_path (str): path to hdf5
+
+ obs_keys (tuple, list): keys to observation items (image, object, etc) to be fetched from the dataset
+
+ action_config (dict): TODO
+
+ dataset_keys (tuple, list): keys to dataset items (actions, rewards, etc) to be fetched from the dataset
+
+ frame_stack (int): numbers of stacked frames to fetch. Defaults to 1 (single frame).
+
+ seq_length (int): length of sequences to sample. Defaults to 1 (single frame).
+
+ pad_frame_stack (int): whether to pad sequence for frame stacking at the beginning of a demo. This
+ ensures that partial frame stacks are observed, such as (s_0, s_0, s_0, s_1). Otherwise, the
+ first frame stacked observation would be (s_0, s_1, s_2, s_3).
+
+ pad_seq_length (int): whether to pad sequence for sequence fetching at the end of a demo. This
+ ensures that partial sequences at the end of a demonstration are observed, such as
+ (s_{T-1}, s_{T}, s_{T}, s_{T}). Otherwise, the last sequence provided would be
+ (s_{T-3}, s_{T-2}, s_{T-1}, s_{T}).
+
+ get_pad_mask (bool): if True, also provide padding masks as part of the batch. This can be
+ useful for masking loss functions on padded parts of the data.
+
+ goal_mode (str): either "last" or None. Defaults to None, which is to not fetch goals
+
+ hdf5_cache_mode (str): one of ["all", "low_dim", or None]. Set to "all" to cache entire hdf5
+ in memory - this is by far the fastest for data loading. Set to "low_dim" to cache all
+ non-image data. Set to None to use no caching - in this case, every batch sample is
+ retrieved via file i/o. You should almost never set this to None, even for large
+ image datasets.
+
+ hdf5_use_swmr (bool): whether to use swmr feature when opening the hdf5 file. This ensures
+ that multiple Dataset instances can all access the same hdf5 file without problems.
+
+ hdf5_normalize_obs (bool): if True, normalize observations by computing the mean observation
+ and std of each observation (in each dimension and modality), and normalizing to unit
+ mean and variance in each dimension.
+
+ filter_by_attribute (str): if provided, use the provided filter key to look up a subset of
+ demonstrations to load
+
+ load_next_obs (bool): whether to load next_obs from the dataset
+ """
+ super(SequenceDataset, self).__init__()
+
+ self.hdf5_path = os.path.expandvars(os.path.expanduser(hdf5_path))
+ self.hdf5_use_swmr = hdf5_use_swmr
+ self.hdf5_normalize_obs = hdf5_normalize_obs
+ self._hdf5_file = None
+
+ assert hdf5_cache_mode in ["all", "low_dim", None]
+ self.hdf5_cache_mode = hdf5_cache_mode
+
+ self.load_next_obs = load_next_obs
+ self.filter_by_attribute = filter_by_attribute
+
+ # get all keys that needs to be fetched
+ self.obs_keys = tuple(obs_keys)
+ self.action_keys = tuple(action_keys)
+ self.dataset_keys = tuple(dataset_keys)
+ # add action keys to dataset keys
+ if self.action_keys is not None:
+ self.dataset_keys = tuple(set(self.dataset_keys).union(set(self.action_keys)))
+
+ self.action_config = action_config
+
+ self.n_frame_stack = frame_stack
+ assert self.n_frame_stack >= 1
+
+ self.seq_length = seq_length
+ assert self.seq_length >= 1
+
+ self.goal_mode = goal_mode
+ if self.goal_mode is not None:
+ assert self.goal_mode in ["last"]
+ if not self.load_next_obs:
+ assert self.goal_mode != "last" # we use last next_obs as goal
+
+ self.pad_seq_length = pad_seq_length
+ self.pad_frame_stack = pad_frame_stack
+ self.get_pad_mask = get_pad_mask
+
+ self.load_demo_info(filter_by_attribute=self.filter_by_attribute)
+
+ # maybe prepare for observation normalization
+ self.obs_normalization_stats = None
+ if self.hdf5_normalize_obs:
+ self.obs_normalization_stats = self.normalize_obs()
+
+ # prepare for action normalization
+ self.action_normalization_stats = None
+
+ # maybe store dataset in memory for fast access
+ if self.hdf5_cache_mode in ["all", "low_dim"]:
+ obs_keys_in_memory = self.obs_keys
+ if self.hdf5_cache_mode == "low_dim":
+ # only store low-dim observations
+ obs_keys_in_memory = []
+ for k in self.obs_keys:
+ if ObsUtils.key_is_obs_modality(k, "low_dim"):
+ obs_keys_in_memory.append(k)
+ self.obs_keys_in_memory = obs_keys_in_memory
+
+ self.hdf5_cache = self.load_dataset_in_memory(
+ demo_list=self.demos,
+ hdf5_file=self.hdf5_file,
+ obs_keys=self.obs_keys_in_memory,
+ dataset_keys=self.dataset_keys,
+ load_next_obs=self.load_next_obs
+ )
+
+ if self.hdf5_cache_mode == "all":
+ # cache getitem calls for even more speedup. We don't do this for
+ # "low-dim" since image observations require calls to getitem anyways.
+ print("SequenceDataset: caching get_item calls...")
+ self.getitem_cache = [self.get_item(i) for i in LogUtils.custom_tqdm(range(len(self)))]
+
+ # don't need the previous cache anymore
+ del self.hdf5_cache
+ self.hdf5_cache = None
+ else:
+ self.hdf5_cache = None
+
+ self.close_and_delete_hdf5_handle()
+
+ def load_demo_info(self, filter_by_attribute=None, demos=None):
+ """
+ Args:
+ filter_by_attribute (str): if provided, use the provided filter key
+ to select a subset of demonstration trajectories to load
+
+ demos (list): list of demonstration keys to load from the hdf5 file. If
+ omitted, all demos in the file (or under the @filter_by_attribute
+ filter key) are used.
+ """
+ # filter demo trajectory by mask
+ if demos is not None:
+ self.demos = demos
+ elif filter_by_attribute is not None:
+ self.demos = [elem.decode("utf-8") for elem in np.array(self.hdf5_file["mask/{}".format(filter_by_attribute)][:])]
+ else:
+ self.demos = list(self.hdf5_file["data"].keys())
+
+ # sort demo keys
+ inds = np.argsort([int(elem[5:]) for elem in self.demos])
+ self.demos = [self.demos[i] for i in inds]
+
+ self.n_demos = len(self.demos)
+
+ # keep internal index maps to know which transitions belong to which demos
+ self._index_to_demo_id = dict() # maps every index to a demo id
+ self._demo_id_to_start_indices = dict() # gives start index per demo id
+ self._demo_id_to_demo_length = dict()
+
+ # determine index mapping
+ self.total_num_sequences = 0
+ for ep in self.demos:
+ demo_length = self.hdf5_file["data/{}".format(ep)].attrs["num_samples"]
+ self._demo_id_to_start_indices[ep] = self.total_num_sequences
+ self._demo_id_to_demo_length[ep] = demo_length
+
+ num_sequences = demo_length
+ # determine actual number of sequences taking into account whether to pad for frame_stack and seq_length
+ if not self.pad_frame_stack:
+ num_sequences -= (self.n_frame_stack - 1)
+ if not self.pad_seq_length:
+ num_sequences -= (self.seq_length - 1)
+
+ if self.pad_seq_length:
+ assert demo_length >= 1 # sequence needs to have at least one sample
+ num_sequences = max(num_sequences, 1)
+ else:
+ assert num_sequences >= 1 # assume demo_length >= (self.n_frame_stack - 1 + self.seq_length)
+
+ for _ in range(num_sequences):
+ self._index_to_demo_id[self.total_num_sequences] = ep
+ self.total_num_sequences += 1
+
+ @property
+ def hdf5_file(self):
+ """
+ This property allows for a lazy hdf5 file open.
+ """
+ if self._hdf5_file is None:
+ self._hdf5_file = h5py.File(self.hdf5_path, 'r', swmr=self.hdf5_use_swmr, libver='latest')
+ return self._hdf5_file
+
+ def close_and_delete_hdf5_handle(self):
+ """
+ Maybe close the file handle.
+ """
+ if self._hdf5_file is not None:
+ self._hdf5_file.close()
+ self._hdf5_file = None
+
+ @contextmanager
+ def hdf5_file_opened(self):
+ """
+ Convenient context manager to open the file on entering the scope
+ and then close it on leaving.
+ """
+ should_close = self._hdf5_file is None
+ yield self.hdf5_file
+ if should_close:
+ self.close_and_delete_hdf5_handle()
+
+ def __del__(self):
+ self.close_and_delete_hdf5_handle()
+
+ def __repr__(self):
+ """
+ Pretty print the class and important attributes on a call to `print`.
+ """
+ msg = str(self.__class__.__name__)
+ msg += " (\n\tpath={}\n\tobs_keys={}\n\tseq_length={}\n\tfilter_key={}\n\tframe_stack={}\n"
+ msg += "\tpad_seq_length={}\n\tpad_frame_stack={}\n\tgoal_mode={}\n"
+ msg += "\tcache_mode={}\n"
+ msg += "\tnum_demos={}\n\tnum_sequences={}\n)"
+ filter_key_str = self.filter_by_attribute if self.filter_by_attribute is not None else "none"
+ goal_mode_str = self.goal_mode if self.goal_mode is not None else "none"
+ cache_mode_str = self.hdf5_cache_mode if self.hdf5_cache_mode is not None else "none"
+ msg = msg.format(self.hdf5_path, self.obs_keys, self.seq_length, filter_key_str, self.n_frame_stack,
+ self.pad_seq_length, self.pad_frame_stack, goal_mode_str, cache_mode_str,
+ self.n_demos, self.total_num_sequences)
+ return msg
+
+ def __len__(self):
+ """
+ Ensure that the torch dataloader will do a complete pass through all sequences in
+ the dataset before starting a new iteration.
+ """
+ return self.total_num_sequences
+
+ def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs):
+ """
+ Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
+ differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
+ `getitem` operation.
+
+ Args:
+ demo_list (list): list of demo keys, e.g., 'demo_0'
+ hdf5_file (h5py.File): file handle to the hdf5 dataset.
+ obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
+ dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
+ load_next_obs (bool): whether to load next_obs from the dataset
+
+ Returns:
+ all_data (dict): dictionary of loaded data.
+ """
+ all_data = dict()
+ print("SequenceDataset: loading dataset into memory...")
+ for ep in LogUtils.custom_tqdm(demo_list):
+ all_data[ep] = {}
+ all_data[ep]["attrs"] = {}
+ all_data[ep]["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs["num_samples"]
+ # get obs
+ all_data[ep]["obs"] = {k: hdf5_file["data/{}/obs/{}".format(ep, k)][()] for k in obs_keys}
+ if load_next_obs:
+ all_data[ep]["next_obs"] = {k: hdf5_file["data/{}/next_obs/{}".format(ep, k)][()] for k in obs_keys}
+ # get other dataset keys
+ for k in dataset_keys:
+ if k in hdf5_file["data/{}".format(ep)]:
+ all_data[ep][k] = hdf5_file["data/{}/{}".format(ep, k)][()].astype('float32')
+ else:
+ all_data[ep][k] = np.zeros((all_data[ep]["attrs"]["num_samples"], 1), dtype=np.float32)
+
+ if "model_file" in hdf5_file["data/{}".format(ep)].attrs:
+ all_data[ep]["attrs"]["model_file"] = hdf5_file["data/{}".format(ep)].attrs["model_file"]
+
+ return all_data
+
+ def normalize_obs(self):
+ """
+ Computes a dataset-wide mean and standard deviation for the observations
+ (per dimension and per obs key) and returns it.
+ """
+
+ # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate
+ # with the previous statistics.
+ ep = self.demos[0]
+ obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys}
+ obs_traj = ObsUtils.process_obs_dict(obs_traj)
+ merged_stats = _compute_traj_stats(obs_traj)
+ print("SequenceDataset: normalizing observations...")
+ for ep in LogUtils.custom_tqdm(self.demos[1:]):
+ obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys}
+ obs_traj = ObsUtils.process_obs_dict(obs_traj)
+ traj_stats = _compute_traj_stats(obs_traj)
+ merged_stats = _aggregate_traj_stats(merged_stats, traj_stats)
+
+ obs_normalization_stats = { k : {} for k in merged_stats }
+ for k in merged_stats:
+ # note we add a small tolerance of 1e-3 for std
+ obs_normalization_stats[k]["mean"] = merged_stats[k]["mean"].astype(np.float32)
+ obs_normalization_stats[k]["std"] = (np.sqrt(merged_stats[k]["sqdiff"] / merged_stats[k]["n"]) + 1e-3).astype(np.float32)
+ return obs_normalization_stats
+
+ def get_obs_normalization_stats(self):
+ """
+ Returns dictionary of mean and std for each observation key if using
+ observation normalization, otherwise None.
+
+ Returns:
+ obs_normalization_stats (dict): a dictionary for observation
+ normalization. This maps observation keys to dicts
+ with a "mean" and "std" of shape (1, ...) where ... is the default
+ shape for the observation.
+ """
+ assert self.hdf5_normalize_obs, "not using observation normalization!"
+ return deepcopy(self.obs_normalization_stats)
+
+ def get_action_traj(self, ep):
+ action_traj = dict()
+ for key in self.action_keys:
+ action_traj[key] = self.hdf5_file["data/{}/{}".format(ep, key)][()].astype('float32')
+ return action_traj
+
+ def get_action_stats(self):
+ ep = self.demos[0]
+ action_traj = self.get_action_traj(ep)
+ action_stats = _compute_traj_stats(action_traj)
+ print("SequenceDataset: normalizing actions...")
+ for ep in LogUtils.custom_tqdm(self.demos[1:]):
+ action_traj = self.get_action_traj(ep)
+ traj_stats = _compute_traj_stats(action_traj)
+ action_stats = _aggregate_traj_stats(action_stats, traj_stats)
+ return action_stats
+
+ def set_action_normalization_stats(self, action_normalization_stats):
+ self.action_normalization_stats = action_normalization_stats
+
+ def get_action_normalization_stats(self):
+ """
+ Computes a dataset-wide min, max, mean and standard deviation for the actions
+ (per dimension) and returns it.
+ """
+
+ # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate
+ # with the previous statistics.
+ if self.action_normalization_stats is None:
+ action_stats = self.get_action_stats()
+ self.action_normalization_stats = action_stats_to_normalization_stats(
+ action_stats, self.action_config)
+ return self.action_normalization_stats
+
+ def get_dataset_for_ep(self, ep, key):
+ """
+ Helper utility to get a dataset for a specific demonstration.
+ Takes into account whether the dataset has been loaded into memory.
+ """
+
+ # check if this key should be in memory
+ key_should_be_in_memory = (self.hdf5_cache_mode in ["all", "low_dim"])
+ if key_should_be_in_memory:
+ # if key is an observation, it may not be in memory
+ if '/' in key:
+ key1, key2 = key.split('/')
+ assert(key1 in ['obs', 'next_obs', 'action_dict'])
+ if key2 not in self.obs_keys_in_memory:
+ key_should_be_in_memory = False
+
+ if key_should_be_in_memory:
+ # read cache
+ if '/' in key:
+ key1, key2 = key.split('/')
+ assert(key1 in ['obs', 'next_obs', 'action_dict'])
+ ret = self.hdf5_cache[ep][key1][key2]
+ else:
+ ret = self.hdf5_cache[ep][key]
+ else:
+ # read from file
+ hd5key = "data/{}/{}".format(ep, key)
+ ret = self.hdf5_file[hd5key]
+ return ret
+
+ def __getitem__(self, index):
+ """
+ Fetch dataset sequence @index (inferred through internal index map), using the getitem_cache if available.
+ """
+ if self.hdf5_cache_mode == "all":
+ return self.getitem_cache[index]
+ return self.get_item(index)
+
+ def get_item(self, index):
+ """
+ Main implementation of getitem when not using cache.
+ """
+
+ demo_id = self._index_to_demo_id[index]
+ demo_start_index = self._demo_id_to_start_indices[demo_id]
+ demo_length = self._demo_id_to_demo_length[demo_id]
+
+ # start at offset index if not padding for frame stacking
+ demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1)
+ index_in_demo = index - demo_start_index + demo_index_offset
+
+ # end at offset index if not padding for seq length
+ demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1)
+ end_index_in_demo = demo_length - demo_length_offset
+
+ meta = self.get_dataset_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=self.dataset_keys,
+ num_frames_to_stack=self.n_frame_stack - 1, # note: need to decrement self.n_frame_stack by one
+ seq_length=self.seq_length
+ )
+
+ # determine goal index
+ goal_index = None
+ if self.goal_mode == "last":
+ goal_index = end_index_in_demo - 1
+
+ meta["obs"] = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=self.obs_keys,
+ num_frames_to_stack=self.n_frame_stack - 1,
+ seq_length=self.seq_length,
+ prefix="obs"
+ )
+
+ if self.load_next_obs:
+ meta["next_obs"] = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=self.obs_keys,
+ num_frames_to_stack=self.n_frame_stack - 1,
+ seq_length=self.seq_length,
+ prefix="next_obs"
+ )
+
+ if goal_index is not None:
+ goal = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=goal_index,
+ keys=self.obs_keys,
+ num_frames_to_stack=0,
+ seq_length=1,
+ prefix="next_obs",
+ )
+ meta["goal_obs"] = {k: goal[k][0] for k in goal} # remove sequence dimension for goal
+
+ # get action components
+ ac_dict = OrderedDict()
+ for k in self.action_keys:
+ ac = meta[k]
+ # expand action shape if needed
+ if len(ac.shape) == 1:
+ ac = ac.reshape(-1, 1)
+ ac_dict[k] = ac
+
+ # normalize actions
+ action_normalization_stats = self.get_action_normalization_stats()
+ ac_dict = ObsUtils.normalize_dict(ac_dict, normalization_stats=action_normalization_stats)
+
+ # concatenate all action components
+ meta["actions"] = AcUtils.action_dict_to_vector(ac_dict)
+
+ # also return the sampled index
+ meta["index"] = index
+
+ return meta
+
+ def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1):
+ """
+ Extract a (sub)sequence of data items from a demo given the @keys of the items.
+
+ Args:
+ demo_id (str): id of the demo, e.g., demo_0
+ index_in_demo (int): beginning index of the sequence wrt the demo
+ keys (tuple): list of keys to extract
+ num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
+ seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range
+
+ Returns:
+ a dictionary of extracted items.
+ """
+ assert num_frames_to_stack >= 0
+ assert seq_length >= 1
+
+ demo_length = self._demo_id_to_demo_length[demo_id]
+ assert index_in_demo < demo_length
+
+ # determine begin and end of sequence
+ seq_begin_index = max(0, index_in_demo - num_frames_to_stack)
+ seq_end_index = min(demo_length, index_in_demo + seq_length)
+
+ # determine sequence padding
+ seq_begin_pad = max(0, num_frames_to_stack - index_in_demo) # pad for frame stacking
+ seq_end_pad = max(0, index_in_demo + seq_length - demo_length) # pad for sequence length
+
+ # make sure we are not padding if specified.
+ if not self.pad_frame_stack:
+ assert seq_begin_pad == 0
+ if not self.pad_seq_length:
+ assert seq_end_pad == 0
+
+ # fetch observation from the dataset file
+ seq = dict()
+ for k in keys:
+ data = self.get_dataset_for_ep(demo_id, k)
+ seq[k] = data[seq_begin_index: seq_end_index]
+
+ seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
+ pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad)
+ pad_mask = pad_mask[:, None].astype(bool)
+
+ return seq, pad_mask
+
+ def get_obs_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1, prefix="obs"):
+ """
+ Extract a (sub)sequence of observation items from a demo given the @keys of the items.
+
+ Args:
+ demo_id (str): id of the demo, e.g., demo_0
+ index_in_demo (int): beginning index of the sequence wrt the demo
+ keys (tuple): list of keys to extract
+ num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
+ seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range
+ prefix (str): one of "obs", "next_obs"
+
+ Returns:
+ a dictionary of extracted items.
+ """
+ obs, pad_mask = self.get_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=tuple('{}/{}'.format(prefix, k) for k in keys),
+ num_frames_to_stack=num_frames_to_stack,
+ seq_length=seq_length,
+ )
+ obs = {'/'.join(k.split('/')[1:]): obs[k] for k in obs} # strip the prefix
+ if self.get_pad_mask:
+ obs["pad_mask"] = pad_mask
+
+ return obs
+
+ def get_dataset_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1):
+ """
+ Extract a (sub)sequence of dataset items from a demo given the @keys of the items (e.g., states, actions).
+
+ Args:
+ demo_id (str): id of the demo, e.g., demo_0
+ index_in_demo (int): beginning index of the sequence wrt the demo
+ keys (tuple): list of keys to extract
+ num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
+ seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range
+
+ Returns:
+ a dictionary of extracted items.
+ """
+ data, pad_mask = self.get_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=keys,
+ num_frames_to_stack=num_frames_to_stack,
+ seq_length=seq_length,
+ )
+ if self.get_pad_mask:
+ data["pad_mask"] = pad_mask
+ return data
+
+ def get_trajectory_at_index(self, index):
+ """
+ Method provided as a utility to get an entire trajectory, given
+ the corresponding @index.
+ """
+ demo_id = self.demos[index]
+ demo_length = self._demo_id_to_demo_length[demo_id]
+
+ meta = self.get_dataset_sequence_from_demo(
+ demo_id,
+ index_in_demo=0,
+ keys=self.dataset_keys,
+ num_frames_to_stack=self.n_frame_stack - 1, # note: need to decrement self.n_frame_stack by one
+ seq_length=demo_length
+ )
+ meta["obs"] = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=0,
+ keys=self.obs_keys,
+ seq_length=demo_length
+ )
+ if self.load_next_obs:
+ meta["next_obs"] = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=0,
+ keys=self.obs_keys,
+ seq_length=demo_length,
+ prefix="next_obs"
+ )
+
+ meta["ep"] = demo_id
+ return meta
+
+ def get_dataset_sampler(self):
+ """
+ Return instance of torch.utils.data.Sampler or None. Allows
+ for dataset to define custom sampling logic, such as
+ re-weighting the probability of samples being drawn.
+ See the `train` function in scripts/train.py, and torch
+ `DataLoader` documentation, for more info.
+ """
+ return None
+
+
+class R2D2Dataset(SequenceDataset):
+ def get_action_traj(self, ep):
+ action_traj = dict()
+ for key in self.action_keys:
+ action_traj[key] = self.hdf5_file[key][()].astype('float32')
+ if len(action_traj[key].shape) == 1:
+ action_traj[key] = np.reshape(action_traj[key], (-1, 1))
+
+ return action_traj
+
+ def load_demo_info(self, filter_by_attribute=None, demos=None, n_demos=None):
+ """
+ Args:
+ filter_by_attribute (str): if provided, use the provided filter key
+ to select a subset of demonstration trajectories to load
+
+ demos (list): list of demonstration keys to load from the hdf5 file. If
+ omitted, all demos in the file (or under the @filter_by_attribute
+ filter key) are used.
+ """
+
+ self.demos = ["demo"]
+
+ self.n_demos = len(self.demos)
+
+ # keep internal index maps to know which transitions belong to which demos
+ self._index_to_demo_id = dict() # maps every index to a demo id
+ self._demo_id_to_start_indices = dict() # gives start index per demo id
+ self._demo_id_to_demo_length = dict()
+
+ # segment time stamps
+ self._demo_id_to_segments = dict()
+
+ ep = self.demos[0]
+
+ # determine index mapping
+ self.total_num_sequences = 0
+ demo_length = self.hdf5_file["action/cartesian_velocity"].shape[0]
+ self._demo_id_to_start_indices[ep] = self.total_num_sequences
+ self._demo_id_to_demo_length[ep] = demo_length
+
+ # seperate demo into segments for better alignment
+ gripper_actions = list(self.hdf5_file["action/gripper_position"])
+ gripper_closed = [1 if x > 0 else 0 for x in gripper_actions]
+
+ try:
+ # find when the gripper fist opens/closes
+ gripper_close = gripper_closed.index(1)
+ gripper_open = gripper_close + gripper_closed[gripper_close:].index(0)
+ except ValueError:
+ # special case for (invalid) trajectories
+ gripper_close, gripper_open = int(demo_length / 3), int(demo_length / 3 * 2)
+ print("No gripper action:", gripper_actions)
+ self._demo_id_to_segments[ep] = [0, gripper_close, gripper_open, demo_length - 1]
+
+ num_sequences = demo_length
+ # determine actual number of sequences taking into account whether to pad for frame_stack and seq_length
+ if not self.pad_frame_stack:
+ num_sequences -= (self.n_frame_stack - 1)
+ if not self.pad_seq_length:
+ num_sequences -= (self.seq_length - 1)
+
+ if self.pad_seq_length:
+ assert demo_length >= 1 # sequence needs to have at least one sample
+ num_sequences = max(num_sequences, 1)
+ else:
+ assert num_sequences >= 1 # assume demo_length >= (self.n_frame_stack - 1 + self.seq_length)
+
+ for _ in range(num_sequences):
+ self._index_to_demo_id[self.total_num_sequences] = ep
+ self.total_num_sequences += 1
+
+ def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs):
+ """
+ Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
+ differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
+ `getitem` operation.
+
+ Args:
+ demo_list (list): list of demo keys, e.g., 'demo_0'
+ hdf5_file (h5py.File): file handle to the hdf5 dataset.
+ obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
+ dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
+ load_next_obs (bool): whether to load next_obs from the dataset
+
+ Returns:
+ all_data (dict): dictionary of loaded data.
+ """
+ all_data = dict()
+ print("SequenceDataset: loading dataset into memory...")
+
+ for ep in LogUtils.custom_tqdm(demo_list):
+ all_data[ep] = {}
+ all_data[ep]["attrs"] = {}
+ all_data[ep]["attrs"]["num_samples"] = hdf5_file["action/cartesian_velocity"].shape[0] # hack to get traj len
+ # get obs
+ all_data[ep]["obs"] = {k: hdf5_file["observation/{}".format(k)][()].astype('float32') for k in obs_keys}
+ if load_next_obs:
+ raise NotImplementedError
+ # get other dataset keys
+ for k in dataset_keys:
+ if k in hdf5_file.keys():
+ all_data[ep][k] = hdf5_file["{}".format(k)][()].astype('float32')
+ else:
+ raise NotImplementedError
+
+ return all_data
+
+ def get_dataset_for_ep(self, ep, key, try_to_use_cache=True):
+ """
+ Helper utility to get a dataset for a specific demonstration.
+ Takes into account whether the dataset has been loaded into memory.
+ """
+
+ # check if this key should be in memory
+ key_should_be_in_memory = try_to_use_cache and (self.hdf5_cache_mode in ["all", "low_dim"])
+ if key_should_be_in_memory:
+ # if key is an observation, it may not be in memory
+ if '/' in key:
+ key_splits = key.split('/')
+ key1 = key_splits[0]
+ key2 = "/".join(key_splits[1:])
+ if key1 == "observation" and key2 not in self.obs_keys_in_memory:
+ key_should_be_in_memory = False
+
+ if key_should_be_in_memory:
+ # read cache
+ if '/' in key:
+ key_splits = key.split('/')
+ key1 = key_splits[0]
+ key2 = "/".join(key_splits[1:])
+ if key1 == "observation":
+ ret = self.hdf5_cache[ep]["obs"][key2]
+ else:
+ ret = self.hdf5_cache[ep][key]
+ else:
+ ret = self.hdf5_cache[ep][key]
+ else:
+ # read from file
+ hd5key = "{}".format(key) #"data/{}/{}".format(ep, key)
+ ret = self.hdf5_file[hd5key]
+ return ret
+
+
+ def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1):
+ """
+ Extract a (sub)sequence of data items from a demo given the @keys of the items.
+
+ Args:
+ demo_id (str): id of the demo, e.g., demo_0
+ index_in_demo (int): beginning index of the sequence wrt the demo
+ keys (tuple): list of keys to extract
+ num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
+ seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range
+
+ Returns:
+ a dictionary of extracted items.
+ """
+ assert num_frames_to_stack >= 0
+ assert seq_length >= 1
+
+ demo_length = self._demo_id_to_demo_length[demo_id]
+ assert index_in_demo < demo_length
+
+ # determine begin and end of sequence
+ seq_begin_index = max(0, index_in_demo - num_frames_to_stack)
+ seq_end_index = min(demo_length, index_in_demo + seq_length)
+
+ # determine sequence padding
+ seq_begin_pad = max(0, num_frames_to_stack - index_in_demo) # pad for frame stacking
+ seq_end_pad = max(0, index_in_demo + seq_length - demo_length) # pad for sequence length
+
+ # make sure we are not padding if specified.
+ if not self.pad_frame_stack:
+ assert seq_begin_pad == 0
+ if not self.pad_seq_length:
+ assert seq_end_pad == 0
+
+ # fetch observation from the dataset file
+ seq = dict()
+ for k in keys:
+ data = self.get_dataset_for_ep(demo_id, k)
+ seq[k] = data[seq_begin_index: seq_end_index].astype("float32")
+
+ seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
+ pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad)
+ pad_mask = pad_mask[:, None].astype(bool)
+
+ return seq, pad_mask
+
+
+ def get_item(self, index):
+ """
+ Main implementation of getitem when not using cache.
+ """
+
+ demo_id = self._index_to_demo_id[index]
+ demo_start_index = self._demo_id_to_start_indices[demo_id]
+ demo_length = self._demo_id_to_demo_length[demo_id]
+
+ # start at offset index if not padding for frame stacking
+ demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1)
+ index_in_demo = index - demo_start_index + demo_index_offset
+
+ # end at offset index if not padding for seq length
+ demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1)
+ end_index_in_demo = demo_length - demo_length_offset
+
+ meta = self.get_dataset_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=self.dataset_keys,
+ num_frames_to_stack=self.n_frame_stack - 1,
+ seq_length=self.seq_length,
+ )
+
+ # determine goal index
+ goal_index = None
+ if self.goal_mode == "last":
+ goal_index = end_index_in_demo - 1
+
+ meta["obs"] = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=self.obs_keys,
+ num_frames_to_stack=self.n_frame_stack - 1,
+ seq_length=self.seq_length,
+ prefix="observation"
+ )
+
+ if self.load_next_obs:
+ meta["next_obs"] = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=index_in_demo,
+ keys=self.obs_keys,
+ num_frames_to_stack=self.n_frame_stack - 1,
+ seq_length=self.seq_length,
+ prefix="next_obs"
+ )
+
+ if goal_index is not None:
+ goal = self.get_obs_sequence_from_demo(
+ demo_id,
+ index_in_demo=goal_index,
+ keys=self.obs_keys,
+ num_frames_to_stack=0,
+ seq_length=1,
+ prefix="next_obs",
+ )
+ meta["goal_obs"] = {k: goal[k][0] for k in goal} # remove sequence dimension for goal
+
+ # get action components
+ ac_dict = OrderedDict()
+ for k in self.action_keys:
+ ac = meta[k]
+ # expand action shape if needed
+ if len(ac.shape) == 1:
+ ac = ac.reshape(-1, 1)
+ ac_dict[k] = ac
+
+ # normalize actions
+ action_normalization_stats = self.get_action_normalization_stats()
+ ac_dict = ObsUtils.normalize_dict(ac_dict, normalization_stats=action_normalization_stats)
+
+ # concatenate all action components
+ meta["actions"] = AcUtils.action_dict_to_vector(ac_dict)
+
+ # keys to reshape
+ for k in meta["obs"]:
+ if len(meta["obs"][k].shape) == 1:
+ meta["obs"][k] = np.expand_dims(meta["obs"][k], axis=1)
+
+ # also return the sampled index
+ meta["index"] = index
+
+ return meta
+
+
+class MetaDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ datasets,
+ ds_weights,
+ normalize_weights_by_ds_size=False,
+ ds_labels=None,
+ ):
+ super(MetaDataset, self).__init__()
+ self.datasets = datasets
+ ds_lens = np.array([len(ds) for ds in self.datasets])
+ if normalize_weights_by_ds_size:
+ self.ds_weights = np.array(ds_weights) / ds_lens
+ else:
+ self.ds_weights = ds_weights
+ self._ds_ind_bins = np.cumsum([0] + list(ds_lens))
+
+ # cache mode "all" not supported! The action normalization stats of each
+ # dataset will change after the datasets are already initialized
+ for ds in self.datasets:
+ assert ds.hdf5_cache_mode != "all"
+
+ # compute ds_labels to one hot ids
+ if ds_labels is None:
+ self.ds_labels = ["dummy"]
+ else:
+ self.ds_labels = ds_labels
+
+ unique_labels = sorted(set(self.ds_labels))
+
+ self.ds_labels_to_ids = {}
+ for i, label in enumerate(sorted(unique_labels)):
+ one_hot_id = np.zeros(len(unique_labels))
+ one_hot_id[i] = 1.0
+ self.ds_labels_to_ids[label] = one_hot_id
+
+ # TODO: comment
+ action_stats = self.get_action_stats()
+ self.action_normalization_stats = action_stats_to_normalization_stats(
+ action_stats, self.datasets[0].action_config)
+ self.set_action_normalization_stats(self.action_normalization_stats)
+
+ def __len__(self):
+ return np.sum([len(ds) for ds in self.datasets])
+
+ def __getitem__(self, idx):
+ ds_ind = np.digitize(idx, self._ds_ind_bins) - 1
+ ind_in_ds = idx - self._ds_ind_bins[ds_ind]
+ meta = self.datasets[ds_ind].__getitem__(ind_in_ds)
+ meta["index"] = idx
+ ds_label = self.ds_labels[ds_ind]
+ T = meta["actions"].shape[0]
+ return meta
+
+ def get_ds_label(self, idx):
+ ds_ind = np.digitize(idx, self._ds_ind_bins) - 1
+ ds_label = self.ds_labels[ds_ind]
+ return ds_label
+
+ def get_ds_id(self, idx):
+ ds_ind = np.digitize(idx, self._ds_ind_bins) - 1
+ ds_label = self.ds_labels[ds_ind]
+ return self.ds_labels_to_ids[ds_label]
+
+ def __repr__(self):
+ str_output = '\n'.join([ds.__repr__() for ds in self.datasets])
+ return str_output
+
+ def get_dataset_sampler(self):
+ weights = np.ones(len(self))
+ for i, (start, end) in enumerate(zip(self._ds_ind_bins[:-1], self._ds_ind_bins[1:])):
+ weights[start:end] = self.ds_weights[i]
+
+ sampler = torch.utils.data.WeightedRandomSampler(
+ weights=weights,
+ num_samples=len(self),
+ replacement=True,
+ )
+ return sampler
+
+ def get_action_stats(self):
+ meta_action_stats = self.datasets[0].get_action_stats()
+ for dataset in self.datasets[1:]:
+ ds_action_stats = dataset.get_action_stats()
+ meta_action_stats = _aggregate_traj_stats(meta_action_stats, ds_action_stats)
+
+ return meta_action_stats
+
+ def set_action_normalization_stats(self, action_normalization_stats):
+ self.action_normalization_stats = action_normalization_stats
+ for ds in self.datasets:
+ ds.set_action_normalization_stats(self.action_normalization_stats)
+
+ def get_action_normalization_stats(self):
+ """
+ Computes a dataset-wide min, max, mean and standard deviation for the actions
+ (per dimension) and returns it.
+ """
+
+ # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate
+ # with the previous statistics.
+ if self.action_normalization_stats is None:
+ action_stats = self.get_action_stats()
+ self.action_normalization_stats = action_stats_to_normalization_stats(
+ action_stats, self.datasets[0].action_config)
+ return self.action_normalization_stats
+
+def _compute_traj_stats(traj_obs_dict):
+ """
+ Helper function to compute statistics over a single trajectory of observations.
+ """
+ traj_stats = { k : {} for k in traj_obs_dict }
+ for k in traj_obs_dict:
+ traj_stats[k]["n"] = traj_obs_dict[k].shape[0]
+ traj_stats[k]["mean"] = traj_obs_dict[k].mean(axis=0, keepdims=True) # [1, ...]
+ traj_stats[k]["sqdiff"] = ((traj_obs_dict[k] - traj_stats[k]["mean"]) ** 2).sum(axis=0, keepdims=True) # [1, ...]
+ traj_stats[k]["min"] = traj_obs_dict[k].min(axis=0, keepdims=True)
+ traj_stats[k]["max"] = traj_obs_dict[k].max(axis=0, keepdims=True)
+ return traj_stats
+
+def _aggregate_traj_stats(traj_stats_a, traj_stats_b):
+ """
+ Helper function to aggregate trajectory statistics.
+ See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+ for more information.
+ """
+ merged_stats = {}
+ for k in traj_stats_a:
+ n_a, avg_a, M2_a, min_a, max_a = traj_stats_a[k]["n"], traj_stats_a[k]["mean"], traj_stats_a[k]["sqdiff"], traj_stats_a[k]["min"], traj_stats_a[k]["max"]
+ n_b, avg_b, M2_b, min_b, max_b = traj_stats_b[k]["n"], traj_stats_b[k]["mean"], traj_stats_b[k]["sqdiff"], traj_stats_b[k]["min"], traj_stats_b[k]["max"]
+ n = n_a + n_b
+ mean = (n_a * avg_a + n_b * avg_b) / n
+ delta = (avg_b - avg_a)
+ M2 = M2_a + M2_b + (delta ** 2) * (n_a * n_b) / n
+ min_ = np.minimum(min_a, min_b)
+ max_ = np.maximum(max_a, max_b)
+ merged_stats[k] = dict(n=n, mean=mean, sqdiff=M2, min=min_, max=max_)
+ return merged_stats
+
+def action_stats_to_normalization_stats(action_stats, action_config):
+ action_normalization_stats = OrderedDict()
+ for action_key in action_stats.keys():
+ # get how this action should be normalized from config, default to None
+ norm_method = action_config[action_key].get("normalization", None)
+ if norm_method is None:
+ # no normalization, unit scale, zero offset
+ action_normalization_stats[action_key] = {
+ "scale": np.ones_like(action_stats[action_key]["mean"], dtype=np.float32),
+ "offset": np.zeros_like(action_stats[action_key]["mean"], dtype=np.float32)
+ }
+ elif norm_method == "min_max":
+ # normalize min to -1 and max to 1
+ range_eps = 1e-4
+ input_min = action_stats[action_key]["min"].astype(np.float32)
+ input_max = action_stats[action_key]["max"].astype(np.float32)
+ # instead of -1 and 1 use numbers just below threshold to prevent numerical instability issues
+ output_min = -0.999999
+ output_max = 0.999999
+
+ # ignore input dimentions that is too small to prevent division by zero
+ input_range = input_max - input_min
+ ignore_dim = input_range < range_eps
+ input_range[ignore_dim] = output_max - output_min
+
+ # expected usage of scale and offset
+ # normalized_action = (raw_action - offset) / scale
+ # raw_action = scale * normalized_action + offset
+
+ # eq1: input_max = scale * output_max + offset
+ # eq2: input_min = scale * output_min + offset
+
+ # solution for scale and offset
+ # eq1 - eq2:
+ # input_max - input_min = scale * (output_max - output_min)
+ # (input_max - input_min) / (output_max - output_min) = scale <- eq3
+ # offset = input_min - scale * output_min <- eq4
+ scale = input_range / (output_max - output_min)
+ offset = input_min - scale * output_min
+
+ offset[ignore_dim] = input_min[ignore_dim] - (output_max + output_min) / 2
+
+ action_normalization_stats[action_key] = {
+ "scale": scale,
+ "offset": offset
+ }
+ elif norm_method == "gaussian":
+ # normalize to zero mean unit variance
+ input_mean = action_stats[action_key]["mean"].astype(np.float32)
+ input_std = np.sqrt(action_stats[action_key]["sqdiff"] / action_stats[action_key]["n"]).astype(np.float32)
+
+ # ignore input dimentions that is too small to prevent division by zero
+ std_eps = 1e-6
+ ignore_dim = input_std < std_eps
+ input_std[ignore_dim] = 1.0
+
+ action_normalization_stats[action_key] = {
+ "scale": input_mean,
+ "offset": input_std
+ }
+ else:
+ raise NotImplementedError(
+ 'action_config.actions.normalization: "{}" is not supported'.format(norm_method))
+
+ return action_normalization_stats
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/env_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/env_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..61c7500daaf7026977005e798a483049695611e4
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/env_utils.py
@@ -0,0 +1,385 @@
+"""
+This file contains several utility functions for working with environment
+wrappers provided by the repository, and with environment metadata saved
+in dataset files.
+"""
+from copy import deepcopy
+import robomimic.envs.env_base as EB
+from robomimic.utils.log_utils import log_warning
+
+
+def get_env_class(env_meta=None, env_type=None, env=None):
+ """
+ Return env class from either env_meta, env_type, or env.
+ Note the use of lazy imports - this ensures that modules are only
+ imported when the corresponding env type is requested. This can
+ be useful in practice. For example, a training run that only
+ requires access to gym environments should not need to import
+ robosuite.
+
+ Args:
+ env_meta (dict): environment metadata, which should be loaded from demonstration
+ hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
+ @FileUtils.env_from_checkpoint). Contains 3 keys:
+
+ :`'env_name'`: name of environment
+ :`'type'`: type of environment, should be a value in EB.EnvType
+ :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+
+ env_type (int): the type of environment, which determines the env class that will
+ be instantiated. Should be a value in EB.EnvType.
+
+ env (instance of EB.EnvBase): environment instance
+ """
+ env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env)
+ if env_type == EB.EnvType.ROBOSUITE_TYPE:
+ from robomimic.envs.env_robosuite import EnvRobosuite
+ return EnvRobosuite
+ elif env_type == EB.EnvType.GYM_TYPE:
+ from robomimic.envs.env_gym import EnvGym
+ return EnvGym
+ elif env_type == EB.EnvType.IG_MOMART_TYPE:
+ from robomimic.envs.env_ig_momart import EnvGibsonMOMART
+ return EnvGibsonMOMART
+ elif env_type == EB.EnvType.REAL_TYPE:
+ from robomimic.envs.env_real_panda import EnvRealPanda
+ return EnvRealPanda
+ elif env_type == EB.EnvType.GPRS_REAL_TYPE:
+ from robomimic.envs.env_real_panda_gprs import EnvRealPandaGPRS
+ return EnvRealPandaGPRS
+ raise Exception("code should never reach this point")
+
+
+def get_env_type(env_meta=None, env_type=None, env=None):
+ """
+ Helper function to get env_type from a variety of inputs.
+
+ Args:
+ env_meta (dict): environment metadata, which should be loaded from demonstration
+ hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
+ @FileUtils.env_from_checkpoint). Contains 3 keys:
+
+ :`'env_name'`: name of environment
+ :`'type'`: type of environment, should be a value in EB.EnvType
+ :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+
+ env_type (int): the type of environment, which determines the env class that will
+ be instantiated. Should be a value in EB.EnvType.
+
+ env (instance of EB.EnvBase): environment instance
+ """
+ checks = [(env_meta is not None), (env_type is not None), (env is not None)]
+ assert sum(checks) == 1, "should provide only one of env_meta, env_type, env"
+ if env_meta is not None:
+ env_type = env_meta["type"]
+ elif env is not None:
+ env_type = env.type
+ return env_type
+
+
+def check_env_type(type_to_check, env_meta=None, env_type=None, env=None):
+ """
+ Checks whether the passed env_meta, env_type, or env is of type @type_to_check.
+ Type corresponds to EB.EnvType.
+
+ Args:
+ type_to_check (int): type to check equality against
+
+ env_meta (dict): environment metadata, which should be loaded from demonstration
+ hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
+ @FileUtils.env_from_checkpoint). Contains 3 keys:
+
+ :`'env_name'`: name of environment
+ :`'type'`: type of environment, should be a value in EB.EnvType
+ :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+
+ env_type (int): the type of environment, which determines the env class that will
+ be instantiated. Should be a value in EB.EnvType.
+
+ env (instance of EB.EnvBase): environment instance
+ """
+ env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env)
+ return (env_type == type_to_check)
+
+
+def check_env_version(env, env_meta):
+ """
+ Checks whether the passed env and env_meta dictionary having matching environment versions.
+ Logs warning if cannot find version or versions do not match.
+
+ Args:
+ env (instance of EB.EnvBase): environment instance
+
+ env_meta (dict): environment metadata, which should be loaded from demonstration
+ hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
+ @FileUtils.env_from_checkpoint). Contains following key:
+
+ :`'env_version'`: environment version, type str
+ """
+ env_system_version = env.version
+ env_meta_version = env_meta.get("env_version", None)
+
+ if env_meta_version is None:
+ log_warning(
+ "No environment version found in dataset!"\
+ "\nCannot verify if dataset and installed environment versions match"\
+ )
+ elif env_system_version != env_meta_version:
+ log_warning(
+ "Dataset and installed environment version mismatch!"\
+ "\nDataset environment version: {meta}"\
+ "\nInstalled environment version: {sys}".format(
+ sys=env_system_version,
+ meta=env_meta_version,
+ )
+ )
+
+
+def is_robosuite_env(env_meta=None, env_type=None, env=None):
+ """
+ Determines whether the environment is a robosuite environment. Accepts
+ either env_meta, env_type, or env.
+ """
+ return check_env_type(type_to_check=EB.EnvType.ROBOSUITE_TYPE, env_meta=env_meta, env_type=env_type, env=env)
+
+
+def is_simpler_env(env_meta=None, env_type=None, env=None):
+ return False
+
+
+def is_simpler_ov_env(env_meta=None, env_type=None, env=None):
+ return False
+
+
+def is_factory_env(env_meta=None, env_type=None, env=None):
+ return False
+
+
+def is_furniture_sim_env(env_meta=None, env_type=None, env=None):
+ return False
+
+
+def is_real_robot_env(env_meta=None, env_type=None, env=None):
+ """
+ Determines whether the environment is a real robot environment. Accepts
+ either env_meta, env_type, or env.
+ """
+ return check_env_type(type_to_check=EB.EnvType.REAL_TYPE, env_meta=env_meta, env_type=env_type, env=env)
+
+
+def is_real_robot_gprs_env(env_meta=None, env_type=None, env=None):
+ """
+ Determines whether the environment is a real robot environment. Accepts
+ either env_meta, env_type, or env.
+ """
+ return check_env_type(type_to_check=EB.EnvType.GPRS_REAL_TYPE, env_meta=env_meta, env_type=env_type, env=env)
+
+
+def create_env(
+ env_type,
+ env_name,
+ env_class=None,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ use_depth_obs=False,
+ **kwargs,
+):
+ """
+ Create environment.
+
+ Args:
+ env_type (int): the type of environment, which determines the env class that will
+ be instantiated. Should be a value in EB.EnvType.
+
+ env_name (str): name of environment
+
+ render (bool): if True, environment supports on-screen rendering
+
+ render_offscreen (bool): if True, environment supports off-screen rendering. This
+ is forced to be True if @use_image_obs is True.
+
+ use_image_obs (bool): if True, environment is expected to render rgb image observations
+ on every env.step call. Set this to False for efficiency reasons, if image
+ observations are not required.
+
+ use_depth_obs (bool): if True, environment is expected to render depth image observations
+ on every env.step call. Set this to False for efficiency reasons, if depth
+ observations are not required.
+ """
+
+ # note: pass @postprocess_visual_obs True, to make sure images are processed for network inputs
+ if env_class is None:
+ env_class = get_env_class(env_type=env_type)
+ env = env_class(
+ env_name=env_name,
+ render=render,
+ render_offscreen=render_offscreen,
+ use_image_obs=use_image_obs,
+ use_depth_obs=use_depth_obs,
+ postprocess_visual_obs=True,
+ **kwargs,
+ )
+ print("Created environment with name {}".format(env_name))
+ print("Action size is {}".format(env.action_dimension))
+ return env
+
+
+def create_env_from_metadata(
+ env_meta,
+ env_name=None,
+ env_class=None,
+ render=False,
+ render_offscreen=False,
+ use_image_obs=False,
+ use_depth_obs=False,
+):
+ """
+ Create environment.
+
+ Args:
+ env_meta (dict): environment metadata, which should be loaded from demonstration
+ hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
+ @FileUtils.env_from_checkpoint). Contains 3 keys:
+
+ :`'env_name'`: name of environment
+ :`'type'`: type of environment, should be a value in EB.EnvType
+ :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+
+ env_name (str): name of environment. Only needs to be provided if making a different
+ environment from the one in @env_meta.
+
+ render (bool): if True, environment supports on-screen rendering
+
+ render_offscreen (bool): if True, environment supports off-screen rendering. This
+ is forced to be True if @use_image_obs is True.
+
+ use_image_obs (bool): if True, environment is expected to render rgb image observations
+ on every env.step call. Set this to False for efficiency reasons, if image
+ observations are not required.
+
+ use_depth_obs (bool): if True, environment is expected to render depth image observations
+ on every env.step call. Set this to False for efficiency reasons, if depth
+ observations are not required.
+ """
+ if env_name is None:
+ env_name = env_meta["env_name"]
+ env_type = get_env_type(env_meta=env_meta)
+ env_kwargs = env_meta["env_kwargs"]
+ env_kwargs.pop("use_image_obs", None)
+ env_kwargs.pop("use_depth_obs", None)
+
+ env = create_env(
+ env_type=env_type,
+ env_name=env_name,
+ env_class=env_class,
+ render=render,
+ render_offscreen=render_offscreen,
+ use_image_obs=use_image_obs,
+ use_depth_obs=use_depth_obs,
+ **env_kwargs,
+ )
+ check_env_version(env, env_meta)
+ return env
+
+
+def create_env_for_data_processing(
+ env_meta,
+ camera_names,
+ camera_height,
+ camera_width,
+ reward_shaping,
+ env_class=None,
+ render=None,
+ render_offscreen=None,
+ use_image_obs=None,
+ use_depth_obs=None,
+):
+ """
+ Creates environment for processing dataset observations and rewards.
+
+ Args:
+ env_meta (dict): environment metadata, which should be loaded from demonstration
+ hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
+ @FileUtils.env_from_checkpoint). Contains 3 keys:
+
+ :`'env_name'`: name of environment
+ :`'type'`: type of environment, should be a value in EB.EnvType
+ :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+
+ camera_names (list of st): list of camera names that correspond to image observations
+
+ camera_height (int): camera height for all cameras
+
+ camera_width (int): camera width for all cameras
+
+ reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
+
+ render (bool or None): optionally override rendering behavior
+
+ render_offscreen (bool or None): optionally override rendering behavior
+
+ use_image_obs (bool or None): optionally override rendering behavior
+
+ use_depth_obs (bool or None): optionally override rendering behavior
+ """
+ env_name = env_meta["env_name"]
+ env_type = get_env_type(env_meta=env_meta)
+ env_kwargs = env_meta["env_kwargs"]
+ if env_class is None:
+ env_class = get_env_class(env_type=env_type)
+
+ # remove possibly redundant values in kwargs
+ env_kwargs = deepcopy(env_kwargs)
+ env_kwargs.pop("env_name", None)
+ env_kwargs.pop("camera_names", None)
+ env_kwargs.pop("camera_height", None)
+ env_kwargs.pop("camera_width", None)
+ env_kwargs.pop("reward_shaping", None)
+ env_kwargs.pop("render", None)
+ env_kwargs.pop("render_offscreen", None)
+ env_kwargs.pop("use_image_obs", None)
+ env_kwargs.pop("use_depth_obs", None)
+
+ env = env_class.create_for_data_processing(
+ env_name=env_name,
+ camera_names=camera_names,
+ camera_height=camera_height,
+ camera_width=camera_width,
+ reward_shaping=reward_shaping,
+ render=render,
+ render_offscreen=render_offscreen,
+ use_image_obs=use_image_obs,
+ use_depth_obs=use_depth_obs,
+ **env_kwargs,
+ )
+ check_env_version(env, env_meta)
+ return env
+
+
+def set_env_specific_obs_processing(env_meta=None, env_type=None, env=None):
+ """
+ Sets env-specific observation processing. As an example, robosuite depth observations
+ correspond to raw depth and should not be normalized by default, while default depth
+ processing normalizes and clips all values to [0, 1].
+ """
+ if is_robosuite_env(env_meta=env_meta, env_type=env_type, env=env):
+ from robomimic.utils.obs_utils import DepthModality, process_frame, unprocess_frame
+ DepthModality.set_obs_processor(processor=(
+ lambda obs: process_frame(frame=obs, channel_dim=1, scale=None)
+ ))
+ DepthModality.set_obs_unprocessor(unprocessor=(
+ lambda obs: unprocess_frame(frame=obs, channel_dim=1, scale=None)
+ ))
+
+
+def wrap_env_from_config(env, config):
+ """
+ Wraps environment using the provided Config object to determine which wrappers
+ to use (if any).
+ """
+ if ("frame_stack" in config.train) and (config.train.frame_stack > 1):
+ from robomimic.envs.wrappers import FrameStackWrapper
+ env = FrameStackWrapper(env, num_frames=config.train.frame_stack)
+
+ return env
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/file_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e43d07f96ea354e9150c5dd90721139371175c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/file_utils.py
@@ -0,0 +1,616 @@
+"""
+A collection of utility functions for working with files, such as reading metadata from
+demonstration datasets, loading model checkpoints, or downloading dataset files.
+"""
+import os
+import h5py
+import json
+import time
+import urllib.request
+import numpy as np
+from collections import OrderedDict
+from tqdm import tqdm
+
+import torch
+
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.config import config_factory
+from robomimic.algo import algo_factory
+from robomimic.algo import RolloutPolicy
+
+
+def create_hdf5_filter_key(hdf5_path, demo_keys, key_name):
+ """
+ Creates a new hdf5 filter key in hdf5 file @hdf5_path with
+ name @key_name that corresponds to the demonstrations
+ @demo_keys. Filter keys are generally useful to create
+ named subsets of the demonstrations in an hdf5, making it
+ easy to train, test, or report statistics on a subset of
+ the trajectories in a file.
+
+ Returns the list of episode lengths that correspond to the filtering.
+
+ Args:
+ hdf5_path (str): path to hdf5 file
+ demo_keys ([str]): list of demonstration keys which should
+ correspond to this filter key. For example, ["demo_0",
+ "demo_1"].
+ key_name (str): name of filter key to create
+
+ Returns:
+ ep_lengths ([int]): list of episode lengths that corresponds to
+ each demonstration in the new filter key
+ """
+ f = h5py.File(hdf5_path, "a")
+ demos = sorted(list(f["data"].keys()))
+
+ # collect episode lengths for the keys of interest
+ ep_lengths = []
+ for ep in demos:
+ ep_data_grp = f["data/{}".format(ep)]
+ if ep in demo_keys:
+ ep_lengths.append(ep_data_grp.attrs["num_samples"])
+
+ # store list of filtered keys under mask group
+ k = "mask/{}".format(key_name)
+ if k in f:
+ del f[k]
+ f[k] = np.array(demo_keys, dtype='S')
+
+ f.close()
+ return ep_lengths
+
+
+def get_demos_for_filter_key(hdf5_path, filter_key):
+ """
+ Gets demo keys that correspond to a particular filter key.
+
+ Args:
+ hdf5_path (str): path to hdf5 file
+ filter_key (str): name of filter key
+
+ Returns:
+ demo_keys ([str]): list of demonstration keys that
+ correspond to this filter key. For example, ["demo_0",
+ "demo_1"].
+ """
+ f = h5py.File(hdf5_path, "r")
+ demo_keys = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(filter_key)][:])]
+ f.close()
+ return demo_keys
+
+
+def get_env_metadata_from_dataset(dataset_path, ds_format="robomimic", set_env_specific_obs_processors=True):
+ """
+ Retrieves env metadata from dataset.
+
+ Args:
+ dataset_path (str): path to dataset
+
+ set_env_specific_obs_processors (bool): environment might have custom rules for how to process
+ observations - if this flag is true, make sure ObsUtils will use these custom settings. This
+ is a good place to do this operation to make sure it happens before loading data, running a
+ trained model, etc.
+
+ Returns:
+ env_meta (dict): environment metadata. Contains 3 keys:
+
+ :`'env_name'`: name of environment
+ :`'type'`: type of environment, should be a value in EB.EnvType
+ :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
+ """
+ dataset_path = os.path.expandvars(os.path.expanduser(dataset_path))
+ f = h5py.File(dataset_path, "r")
+ if ds_format == "robomimic":
+ env_meta = json.loads(f["data"].attrs["env_args"])
+ elif ds_format == "r2d2":
+ env_meta = dict(f.attrs)
+ else:
+ raise ValueError
+ f.close()
+ if set_env_specific_obs_processors:
+ # handle env-specific custom observation processing logic
+ EnvUtils.set_env_specific_obs_processing(env_meta=env_meta)
+ return env_meta
+
+
+def get_shape_metadata_from_dataset(dataset_path, action_keys, all_obs_keys=None, ds_format="robomimic", verbose=False):
+ """
+ Retrieves shape metadata from dataset.
+
+ Args:
+ dataset_path (str): path to dataset
+ action_keys (list): list of all action key strings
+ all_obs_keys (list): list of all modalities used by the model. If not provided, all modalities
+ present in the file are used.
+ verbose (bool): if True, include print statements
+
+ Returns:
+ shape_meta (dict): shape metadata. Contains the following keys:
+
+ :`'ac_dim'`: action space dimension
+ :`'all_shapes'`: dictionary that maps observation key string to shape
+ :`'all_obs_keys'`: list of all observation modalities used
+ :`'use_images'`: bool, whether or not image modalities are present
+ :`'use_depths'`: bool, whether or not depth modalities are present
+ """
+
+ shape_meta = {}
+
+ # read demo file for some metadata
+ dataset_path = os.path.expandvars(os.path.expanduser(dataset_path))
+ f = h5py.File(dataset_path, "r")
+
+ if ds_format == "robomimic":
+ demo_id = list(f["data"].keys())[0]
+ demo = f["data/{}".format(demo_id)]
+
+ for key in action_keys:
+ assert len(demo[key].shape) == 2 # shape should be (B, D)
+ action_dim = sum([demo[key].shape[1] for key in action_keys])
+ shape_meta["ac_dim"] = action_dim
+
+ # observation dimensions
+ all_shapes = OrderedDict()
+
+ if all_obs_keys is None:
+ # use all modalities present in the file
+ all_obs_keys = [k for k in demo["obs"]]
+
+ for k in sorted(all_obs_keys):
+ initial_shape = demo["obs/{}".format(k)].shape[1:]
+ if verbose:
+ print("obs key {} with shape {}".format(k, initial_shape))
+ # Store processed shape for each obs key
+ all_shapes[k] = ObsUtils.get_processed_shape(
+ obs_modality=ObsUtils.OBS_KEYS_TO_MODALITIES[k],
+ input_shape=initial_shape,
+ )
+ elif ds_format == "r2d2":
+ for key in action_keys:
+ assert len(f[key].shape) == 2 # shape should be (B, D)
+ action_dim = sum([f[key].shape[1] for key in action_keys])
+ shape_meta["ac_dim"] = action_dim
+
+ # observation dimensions
+ all_shapes = OrderedDict()
+
+ # hack all relevant obs shapes for now
+ for k in [
+ "robot_state/cartesian_position",
+ "robot_state/gripper_position",
+ "robot_state/joint_positions",
+ "camera/image/hand_camera_image",
+ "camera/image/varied_camera_1_image",
+ "camera/image/varied_camera_2_image",
+ ]:
+ initial_shape = f["observation/{}".format(k)].shape[1:]
+ if len(initial_shape) == 0:
+ initial_shape = (1,)
+
+ all_shapes[k] = ObsUtils.get_processed_shape(
+ obs_modality=ObsUtils.OBS_KEYS_TO_MODALITIES[k],
+ input_shape=initial_shape,
+ )
+ else:
+ raise ValueError
+
+ f.close()
+
+ shape_meta['all_shapes'] = all_shapes
+ shape_meta['all_obs_keys'] = all_obs_keys
+ shape_meta['use_images'] = ObsUtils.has_modality("rgb", all_obs_keys)
+ shape_meta['use_depths'] = ObsUtils.has_modality("depth", all_obs_keys)
+
+ return shape_meta
+
+
+def get_intervention_segments(interventions):
+ """
+ Splits interventions list into a list of start and end indices (windows) of continuous intervention segments.
+ """
+ interventions = interventions.reshape(-1).astype(int)
+ # pad before and after to make it easy to count starting and ending intervention segments
+ expanded_ints = [False] + interventions.astype(bool).tolist() + [False]
+ start_inds = []
+ end_inds = []
+ for i in range(1, len(expanded_ints)):
+ if expanded_ints[i] and (not expanded_ints[i - 1]):
+ # low to high edge means start of new window
+ start_inds.append(i - 1) # record index in original array which is one less (since we added an element to the beg)
+ elif (not expanded_ints[i]) and expanded_ints[i - 1]:
+ # high to low edge means end of previous window
+ end_inds.append(i - 1) # record index in original array which is one less (since we added an element to the beg)
+
+ # run some sanity checks
+ assert len(start_inds) == len(end_inds), "missing window edge"
+ assert np.all([np.sum(interventions[s : e]) == (e - s) for s, e in zip(start_inds, end_inds)]), "window computation covers non-interventions"
+ assert sum([np.sum(interventions[s : e]) for s, e in zip(start_inds, end_inds)]) == np.sum(interventions), "window computation does not cover all interventions"
+ return list(zip(start_inds, end_inds))
+
+
+def load_dict_from_checkpoint(ckpt_path):
+ """
+ Load checkpoint dictionary from a checkpoint file.
+
+ Args:
+ ckpt_path (str): Path to checkpoint file.
+
+ Returns:
+ ckpt_dict (dict): Loaded checkpoint dictionary.
+ """
+ ckpt_path = os.path.expandvars(os.path.expanduser(ckpt_path))
+ if not torch.cuda.is_available():
+ ckpt_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
+ else:
+ ckpt_dict = torch.load(ckpt_path)
+ return ckpt_dict
+
+
+def maybe_dict_from_checkpoint(ckpt_path=None, ckpt_dict=None):
+ """
+ Utility function for the common use case where either an ckpt path
+ or a ckpt_dict is provided. This is a no-op if ckpt_dict is not
+ None, otherwise it loads the model dict from the ckpt path.
+
+ Args:
+ ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict.
+
+ ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path.
+
+ Returns:
+ ckpt_dict (dict): Loaded checkpoint dictionary.
+ """
+ assert (ckpt_path is not None) or (ckpt_dict is not None)
+ if ckpt_dict is None:
+ ckpt_dict = load_dict_from_checkpoint(ckpt_path)
+ return ckpt_dict
+
+
+def algo_name_from_checkpoint(ckpt_path=None, ckpt_dict=None):
+ """
+ Return algorithm name that was used to train a checkpoint or
+ loaded model dictionary.
+
+ Args:
+ ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict.
+
+ ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path.
+
+ Returns:
+ algo_name (str): algorithm name
+
+ ckpt_dict (dict): loaded checkpoint dictionary (convenient to avoid
+ re-loading checkpoint from disk multiple times)
+ """
+ ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path, ckpt_dict=ckpt_dict)
+ algo_name = ckpt_dict["algo_name"]
+ return algo_name, ckpt_dict
+
+
+def update_config(cfg):
+ """
+ Updates the config for backwards-compatibility if it uses outdated configurations.
+
+ See https://github.com/ARISE-Initiative/robomimic/releases/tag/v0.2.0 for more info.
+
+ Args:
+ cfg (dict): Raw dictionary of config values
+ """
+ # Check if image modality is defined -- this means we're using an outdated config
+ # Note: There may be a nested hierarchy, so we possibly check all the nested obs cfgs which can include
+ # e.g. a planner and actor for HBC
+
+ def find_obs_dicts_recursively(dic):
+ dics = []
+ if "modalities" in dic:
+ dics.append(dic)
+ else:
+ for child_dic in dic.values():
+ dics += find_obs_dicts_recursively(child_dic)
+ return dics
+
+ obs_cfgs = find_obs_dicts_recursively(cfg["observation"])
+ for obs_cfg in obs_cfgs:
+ modalities = obs_cfg["modalities"]
+
+ found_img = False
+ for modality_group in ("obs", "subgoal", "goal"):
+ if modality_group in modalities:
+ img_modality = modalities[modality_group].pop("image", None)
+ if img_modality is not None:
+ found_img = True
+ modalities[modality_group]["rgb"] = img_modality
+
+ if found_img:
+ # Also need to map encoder kwargs correctly
+ old_encoder_cfg = obs_cfg.pop("encoder")
+
+ # Create new encoder entry for RGB
+ rgb_encoder_cfg = {
+ "core_class": "VisualCore",
+ "core_kwargs": {
+ "backbone_kwargs": dict(),
+ "pool_kwargs": dict(),
+ },
+ "obs_randomizer_class": None,
+ "obs_randomizer_kwargs": dict(),
+ }
+
+ if "visual_feature_dimension" in old_encoder_cfg:
+ rgb_encoder_cfg["core_kwargs"]["feature_dimension"] = old_encoder_cfg["visual_feature_dimension"]
+
+ if "visual_core" in old_encoder_cfg:
+ rgb_encoder_cfg["core_kwargs"]["backbone_class"] = old_encoder_cfg["visual_core"]
+
+ for kwarg in ("pretrained", "input_coord_conv"):
+ if "visual_core_kwargs" in old_encoder_cfg and kwarg in old_encoder_cfg["visual_core_kwargs"]:
+ rgb_encoder_cfg["core_kwargs"]["backbone_kwargs"][kwarg] = old_encoder_cfg["visual_core_kwargs"][kwarg]
+
+ # Optionally add pooling info too
+ if old_encoder_cfg.get("use_spatial_softmax", True):
+ rgb_encoder_cfg["core_kwargs"]["pool_class"] = "SpatialSoftmax"
+
+ for kwarg in ("num_kp", "learnable_temperature", "temperature", "noise_std"):
+ if "spatial_softmax_kwargs" in old_encoder_cfg and kwarg in old_encoder_cfg["spatial_softmax_kwargs"]:
+ rgb_encoder_cfg["core_kwargs"]["pool_kwargs"][kwarg] = old_encoder_cfg["spatial_softmax_kwargs"][kwarg]
+
+ # Update obs randomizer as well
+ for kwarg in ("obs_randomizer_class", "obs_randomizer_kwargs"):
+ if kwarg in old_encoder_cfg:
+ rgb_encoder_cfg[kwarg] = old_encoder_cfg[kwarg]
+
+ # Store rgb config
+ obs_cfg["encoder"] = {"rgb": rgb_encoder_cfg}
+
+ # Also add defaults for low dim
+ obs_cfg["encoder"]["low_dim"] = {
+ "core_class": None,
+ "core_kwargs": {
+ "backbone_kwargs": dict(),
+ "pool_kwargs": dict(),
+ },
+ "obs_randomizer_class": None,
+ "obs_randomizer_kwargs": dict(),
+ }
+
+
+def config_from_checkpoint(algo_name=None, ckpt_path=None, ckpt_dict=None, verbose=False):
+ """
+ Helper function to restore config from a checkpoint file or loaded model dictionary.
+
+ Args:
+ algo_name (str): Algorithm name.
+
+ ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict.
+
+ ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path.
+
+ verbose (bool): if True, include print statements
+
+ Returns:
+ config (dict): Raw loaded configuration, without properties replaced.
+
+ ckpt_dict (dict): loaded checkpoint dictionary (convenient to avoid
+ re-loading checkpoint from disk multiple times)
+ """
+ ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path, ckpt_dict=ckpt_dict)
+ if algo_name is None:
+ algo_name, _ = algo_name_from_checkpoint(ckpt_dict=ckpt_dict)
+
+ # restore config from loaded model dictionary
+ config_dict = json.loads(ckpt_dict['config'])
+ update_config(cfg=config_dict)
+
+ if verbose:
+ print("============= Loaded Config =============")
+ print(json.dumps(config_dict, indent=4))
+
+ config = config_factory(algo_name, dic=config_dict)
+
+ # lock config to prevent further modifications and ensure missing keys raise errors
+ config.lock()
+
+ return config, ckpt_dict
+
+
+def policy_from_checkpoint(device=None, ckpt_path=None, ckpt_dict=None, verbose=False):
+ """
+ This function restores a trained policy from a checkpoint file or
+ loaded model dictionary.
+
+ Args:
+ device (torch.device): if provided, put model on this device
+
+ ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict.
+
+ ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path.
+
+ verbose (bool): if True, include print statements
+
+ Returns:
+ model (RolloutPolicy): instance of Algo that has the saved weights from
+ the checkpoint file, and also acts as a policy that can easily
+ interact with an environment in a training loop
+
+ ckpt_dict (dict): loaded checkpoint dictionary (convenient to avoid
+ re-loading checkpoint from disk multiple times)
+ """
+ ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path, ckpt_dict=ckpt_dict)
+
+ # algo name and config from model dict
+ algo_name, _ = algo_name_from_checkpoint(ckpt_dict=ckpt_dict)
+ config, _ = config_from_checkpoint(algo_name=algo_name, ckpt_dict=ckpt_dict, verbose=verbose)
+
+ # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
+ ObsUtils.initialize_obs_utils_with_config(config)
+
+ # shape meta from model dict to get info needed to create model
+ shape_meta = ckpt_dict["shape_metadata"]
+
+ # maybe restore observation normalization stats
+ obs_normalization_stats = ckpt_dict.get("obs_normalization_stats", None)
+ if obs_normalization_stats is not None:
+ assert config.train.hdf5_normalize_obs
+ for m in obs_normalization_stats:
+ for k in obs_normalization_stats[m]:
+ obs_normalization_stats[m][k] = np.array(obs_normalization_stats[m][k])
+
+ # maybe restore action normalization stats
+ action_normalization_stats = ckpt_dict.get("action_normalization_stats", None)
+ if action_normalization_stats is not None:
+ for m in action_normalization_stats:
+ for k in action_normalization_stats[m]:
+ action_normalization_stats[m][k] = np.array(action_normalization_stats[m][k])
+
+ if device is None:
+ # get torch device
+ device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
+
+ # create model and load weights
+ model = algo_factory(
+ algo_name,
+ config,
+ obs_key_shapes=shape_meta["all_shapes"],
+ ac_dim=shape_meta["ac_dim"],
+ device=device,
+ )
+ model.deserialize(ckpt_dict["model"])
+ model.set_eval()
+ model = RolloutPolicy(
+ model,
+ obs_normalization_stats=obs_normalization_stats,
+ action_normalization_stats=action_normalization_stats
+ )
+ if verbose:
+ print("============= Loaded Policy =============")
+ print(model)
+ return model, ckpt_dict
+
+
+def env_from_checkpoint(ckpt_path=None, ckpt_dict=None, env_name=None, render=False, render_offscreen=False, verbose=False):
+ """
+ Creates an environment using the metadata saved in a checkpoint.
+
+ Args:
+ ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict.
+
+ ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path.
+
+ env_name (str): if provided, override environment name saved in checkpoint
+
+ render (bool): if True, environment supports on-screen rendering
+
+ render_offscreen (bool): if True, environment supports off-screen rendering. This
+ is forced to be True if saved model uses image observations.
+
+ Returns:
+ env (EnvBase instance): environment created using checkpoint
+
+ ckpt_dict (dict): loaded checkpoint dictionary (convenient to avoid
+ re-loading checkpoint from disk multiple times)
+ """
+ ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path, ckpt_dict=ckpt_dict)
+
+ # metadata from model dict to get info needed to create environment
+ env_meta = ckpt_dict["env_metadata"]
+ shape_meta = ckpt_dict["shape_metadata"]
+
+ # create env from saved metadata
+ env = EnvUtils.create_env_from_metadata(
+ env_meta=env_meta,
+ env_name=env_name,
+ render=render,
+ render_offscreen=render_offscreen,
+ use_image_obs=shape_meta.get("use_images", False),
+ use_depth_obs=shape_meta.get("use_depths", False),
+ )
+ config, _ = config_from_checkpoint(algo_name=ckpt_dict["algo_name"], ckpt_dict=ckpt_dict, verbose=False)
+ env = EnvUtils.wrap_env_from_config(env, config=config) # apply environment wrapper, if applicable
+ if verbose:
+ print("============= Loaded Environment =============")
+ print(env)
+ return env, ckpt_dict
+
+
+class DownloadProgressBar(tqdm):
+ def update_to(self, b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n)
+
+
+def url_is_alive(url):
+ """
+ Checks that a given URL is reachable.
+ From https://gist.github.com/dehowell/884204.
+
+ Args:
+ url (str): url string
+
+ Returns:
+ is_alive (bool): True if url is reachable, False otherwise
+ """
+ request = urllib.request.Request(url)
+ request.get_method = lambda: 'HEAD'
+
+ try:
+ urllib.request.urlopen(request)
+ return True
+ except urllib.request.HTTPError:
+ return False
+
+
+def download_url(url, download_dir, check_overwrite=True):
+ """
+ First checks that @url is reachable, then downloads the file
+ at that url into the directory specified by @download_dir.
+ Prints a progress bar during the download using tqdm.
+
+ Modified from https://github.com/tqdm/tqdm#hooks-and-callbacks, and
+ https://stackoverflow.com/a/53877507.
+
+ Args:
+ url (str): url string
+ download_dir (str): path to directory where file should be downloaded
+ check_overwrite (bool): if True, will sanity check the download fpath to make sure a file of that name
+ doesn't already exist there
+ """
+
+ # check if url is reachable. We need the sleep to make sure server doesn't reject subsequent requests
+ assert url_is_alive(url), "@download_url got unreachable url: {}".format(url)
+ time.sleep(0.5)
+
+ # infer filename from url link
+ fname = url.split("/")[-1]
+ file_to_write = os.path.join(download_dir, fname)
+
+ # If we're checking overwrite and the path already exists,
+ # we ask the user to verify that they want to overwrite the file
+ if check_overwrite and os.path.exists(file_to_write):
+ user_response = input(f"Warning: file {file_to_write} already exists. Overwrite? y/n\n")
+ assert user_response.lower() in {"yes", "y"}, f"Did not receive confirmation. Aborting download."
+
+ with DownloadProgressBar(unit='B', unit_scale=True,
+ miniters=1, desc=fname) as t:
+ urllib.request.urlretrieve(url, filename=file_to_write, reporthook=t.update_to)
+
+
+def find_and_replace_path_prefix(org_path, replace_prefixes, new_prefix, assert_replace=False):
+ """
+ Try to find and replace one of several prefixes (@replace_prefixes) in string @org_path
+ with another prefix (@new_prefix). If @assert_replace is True, the function asserts that
+ replacement did occur.
+ """
+ check_ind = -1
+ for i, x in enumerate(replace_prefixes):
+ if org_path.startswith(x):
+ check_ind = i
+ if assert_replace:
+ assert check_ind != -1
+ if check_ind == -1:
+ return org_path
+ replace_prefix = replace_prefixes[check_ind]
+ return org_path.replace(replace_prefix, new_prefix, 1)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/hyperparam_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/hyperparam_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..460cfbc9e6b4f43d36fde6cdf440332307da7af2
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/hyperparam_utils.py
@@ -0,0 +1,368 @@
+"""
+A collection of utility functions and classes for generating config jsons for hyperparameter sweeps.
+"""
+import argparse
+import os
+import json
+import re
+import itertools
+
+from collections import OrderedDict
+from copy import deepcopy
+
+
+class ConfigGenerator(object):
+ """
+ Useful class to keep track of hyperparameters to sweep, and to generate
+ the json configs for each experiment run.
+ """
+ def __init__(self, base_config_file, wandb_proj_name="debug", script_file=None, generated_config_dir=None):
+ """
+ Args:
+ base_config_file (str): path to a base json config to use as a starting point
+ for the parameter sweep.
+
+ script_file (str): script filename to write as output
+ """
+ assert isinstance(base_config_file, str)
+ self.base_config_file = base_config_file
+ assert generated_config_dir is None or isinstance(generated_config_dir, str)
+ if generated_config_dir is not None:
+ generated_config_dir = os.path.expanduser(generated_config_dir)
+ self.generated_config_dir = generated_config_dir
+ assert script_file is None or isinstance(script_file, str)
+ if script_file is None:
+ self.script_file = os.path.join('~', 'tmp/tmpp.sh')
+ else:
+ self.script_file = script_file
+ self.script_file = os.path.expanduser(self.script_file)
+ self.parameters = OrderedDict()
+
+ assert isinstance(wandb_proj_name, str)
+ self.wandb_proj_name = wandb_proj_name
+
+ def add_param(self, key, name, group, values, value_names=None):
+ """
+ Add parameter to the hyperparameter sweep.
+
+ Args:
+ key (str): location of parameter in the config, using hierarchical key format
+ (ex. train/data = config.train.data)
+
+ name (str): name, as it will appear in the experiment name
+
+ group (int): group id - parameters with the same ID have their values swept
+ together
+
+ values (list): list of values to sweep over for this parameter
+
+ value_names ([str]): if provided, strings to use in experiment name for
+ each value, instead of the parameter value. This is helpful for parameters
+ that may have long or large values (for example, dataset path).
+ """
+ if value_names is not None:
+ assert len(values) == len(value_names)
+ self.parameters[key] = argparse.Namespace(
+ key=key,
+ name=name,
+ group=group,
+ values=values,
+ value_names=value_names,
+ hidename=hidename,
+ )
+
+ def generate(self):
+ """
+ Generates json configs for the hyperparameter sweep using attributes
+ @self.parameters, @self.base_config_file, and @self.script_file,
+ all of which should have first been set externally by calling
+ @add_param, @set_base_config_file, and @set_script_file.
+ """
+ assert len(self.parameters) > 0, "must add parameters using add_param first!"
+ generated_json_paths = self._generate_jsons()
+ self._script_from_jsons(generated_json_paths)
+
+ def _name_for_experiment(self, base_name, parameter_values, parameter_value_names):
+ """
+ This function generates the name for an experiment, given one specific
+ parameter setting.
+
+ Args:
+ base_name (str): base experiment name
+ parameter_values (OrderedDict): dictionary that maps parameter name to
+ the parameter value for this experiment run
+ parameter_value_names (dict): dictionary that maps parameter name to
+ the name to use for its value in the experiment name
+
+ Returns:
+ name (str): generated experiment name
+ """
+ name = base_name
+ for k in parameter_values:
+ # append parameter name and value to end of base name
+ if len(self.parameters[k].name) == 0 or self.parameters[k].hidename:
+ # empty string indicates that naming should be skipped
+ continue
+ if len(self.parameters[k].name) == 0:
+ # empty string indicates that naming should be skipped
+ continue
+ if parameter_value_names[k] is not None:
+ # take name from passed dictionary
+ val_str = parameter_value_names[k]
+ else:
+ val_str = parameter_values[k]
+ if isinstance(parameter_values[k], list) or isinstance(parameter_values[k], tuple):
+ # convert list to string to avoid weird spaces and naming problems
+ val_str = "_".join([str(x) for x in parameter_values[k]])
+ val_str = str(val_str)
+ name += '_{}'.format(self.parameters[k].name)
+ if len(val_str) > 0:
+ name += '_{}'.format(val_str)
+ return name
+
+ def _get_parameter_ranges(self):
+ """
+ Extract parameter ranges from base json file. Also takes all possible
+ combinations of the parameter ranges to generate an expanded set of values.
+
+ Returns:
+ parameter_ranges (dict): dictionary that maps the parameter to a list
+ of all values it should take for each generated config. The length
+ of the list will be the total number of configs that will be
+ generated from this scan.
+
+ parameter_names (dict): dictionary that maps the parameter to a list
+ of all name strings that should contribute to each invididual
+ experiment's name. The length of the list will be the total
+ number of configs that will be generated from this scan.
+ """
+
+ # mapping from group id to list of indices to grab from each parameter's list
+ # of values in the parameter group
+ parameter_group_indices = OrderedDict()
+ for k in self.parameters:
+ group_id = self.parameters[k].group
+ assert isinstance(self.parameters[k].values, list)
+ num_param_values = len(self.parameters[k].values)
+ if group_id not in parameter_group_indices:
+ parameter_group_indices[group_id] = list(range(num_param_values))
+ else:
+ assert len(parameter_group_indices[group_id]) == num_param_values, \
+ "error: inconsistent number of parameter values in group with id {}".format(group_id)
+
+ keys = list(parameter_group_indices.keys())
+ inds = list(parameter_group_indices.values())
+ new_parameter_group_indices = OrderedDict(
+ { k : [] for k in keys }
+ )
+ # get all combinations of the different parameter group indices
+ # and then use these indices to determine the new parameter ranges
+ # per member of each parameter group.
+ #
+ # e.g. with two parameter groups, one with two values, and another with three values
+ # we have [0, 1] x [0, 1, 2] = [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
+ # so the corresponding parameter group indices are [0, 0, 0, 1, 1, 1] and
+ # [0, 1, 2, 0, 1, 2], and all parameters in each parameter group are indexed
+ # together using these indices, to get each parameter range.
+ for comb in itertools.product(*inds):
+ for i in range(len(comb)):
+ new_parameter_group_indices[keys[i]].append(comb[i])
+ parameter_group_indices = new_parameter_group_indices
+
+ # use the indices to gather the parameter values to sweep per parameter
+ parameter_ranges = OrderedDict()
+ parameter_names = OrderedDict()
+ for k in self.parameters:
+ parameter_values = self.parameters[k].values
+ group_id = self.parameters[k].group
+ inds = parameter_group_indices[group_id]
+ parameter_ranges[k] = [parameter_values[ind] for ind in inds]
+
+ # add in parameter names if supplied
+ parameter_names[k] = None
+ if self.parameters[k].value_names is not None:
+ par_names = self.parameters[k].value_names
+ assert isinstance(par_names, list)
+ assert len(par_names) == len(parameter_values)
+ parameter_names[k] = [par_names[ind] for ind in inds]
+
+ # ensure that the number of parameter settings is the same per parameter
+ first_key = list(parameter_ranges.keys())[0]
+ num_settings = len(parameter_ranges[first_key])
+ for k in parameter_ranges:
+ assert len(parameter_ranges[k]) == num_settings, "inconsistent number of values"
+
+ return parameter_ranges, parameter_names
+
+ def _generate_jsons(self):
+ """
+ Generates json configs for the hyperparameter sweep, using @self.parameters and
+ @self.base_config_file.
+
+ Returns:
+ json_paths (list): list of paths to created json files, one per experiment
+ """
+
+ # base directory for saving jsons
+ if self.generated_config_dir:
+ base_dir = self.generated_config_dir
+ if not os.path.exists(base_dir):
+ os.makedirs(base_dir)
+ else:
+ base_dir = os.path.abspath(os.path.dirname(self.base_config_file))
+
+ # read base json
+ base_config = load_json(self.base_config_file, verbose=False)
+
+ # base exp name from this base config
+ base_exp_name = base_config['experiment']['name']
+
+ # use base json to determine the parameter ranges
+ parameter_ranges, parameter_names = self._get_parameter_ranges()
+
+ # iterate through each parameter setting to create each json
+ first_key = list(parameter_ranges.keys())[0]
+ num_settings = len(parameter_ranges[first_key])
+
+ # keep track of path to generated jsons
+ json_paths = []
+
+ for i in range(num_settings):
+ # the specific parameter setting for this experiment
+ setting = { k : parameter_ranges[k][i] for k in parameter_ranges }
+ maybe_parameter_names = OrderedDict()
+ for k in parameter_names:
+ maybe_parameter_names[k] = None
+ if parameter_names[k] is not None:
+ maybe_parameter_names[k] = parameter_names[k][i]
+
+ # experiment name from setting
+ exp_name = self._name_for_experiment(
+ base_name=base_exp_name,
+ parameter_values=setting,
+ parameter_value_names=maybe_parameter_names,
+ )
+
+ # copy old json, but override name, and parameter values
+ json_dict = deepcopy(base_config)
+ json_dict['experiment']['name'] = exp_name
+ for k in parameter_ranges:
+ set_value_for_key(json_dict, k, v=parameter_ranges[k][i])
+
+ # populate list of identifying meta for logger;
+ # see meta_config method in base_config.py for more info
+ json_dict["experiment"]["logging"]["wandb_proj_name"] = self.wandb_proj_name
+ if "meta" not in json_dict:
+ json_dict["meta"] = dict()
+ json_dict["meta"].update(
+ hp_base_config_file=self.base_config_file,
+ hp_keys=list(),
+ hp_values=list(),
+ )
+ # logging: keep track of hyp param names and values as meta info
+ for k in parameter_ranges.keys():
+ key_name = self.parameters[k].name
+ if key_name is not None and len(key_name) > 0:
+ if maybe_parameter_names[k] is not None:
+ value_name = maybe_parameter_names[k]
+ else:
+ value_name = setting[k]
+
+ json_dict["meta"]["hp_keys"].append(key_name)
+ json_dict["meta"]["hp_values"].append(value_name)
+
+ # save file in same directory as old json
+ json_path = os.path.join(base_dir, "{}.json".format(exp_name))
+ save_json(json_dict, json_path)
+ json_paths.append(json_path)
+
+ print("Num exps:", len(json_paths))
+
+ return json_paths
+
+ def _script_from_jsons(self, json_paths):
+ """
+ Generates a bash script to run the experiments that correspond to
+ the input jsons.
+ """
+ with open(self.script_file, 'w') as f:
+ f.write("#!/bin/bash\n\n")
+ for path in json_paths:
+ # write python command to file
+ cmd = "python train.py --config {}\n".format(path)
+
+ print()
+ print(cmd)
+ f.write(cmd)
+
+
+def load_json(json_file, verbose=True):
+ """
+ Simple utility function to load a json file as a dict.
+
+ Args:
+ json_file (str): path to json file to load
+ verbose (bool): if True, pretty print the loaded json dictionary
+
+ Returns:
+ config (dict): json dictionary
+ """
+ with open(json_file, 'r') as f:
+ config = json.load(f)
+ if verbose:
+ print('loading external config: =================')
+ print(json.dumps(config, indent=4))
+ print('==========================================')
+ return config
+
+
+def save_json(config, json_file):
+ """
+ Simple utility function to save a dictionary to a json file on disk.
+
+ Args:
+ config (dict): dictionary to save
+ json_file (str): path to json file to write
+ """
+ with open(json_file, 'w') as f:
+ # preserve original key ordering
+ json.dump(config, f, sort_keys=False, indent=4)
+
+
+def get_value_for_key(dic, k):
+ """
+ Get value for nested dictionary with levels denoted by "/" or ".".
+ For example, if @k is "a/b", then this function returns
+ @dic["a"]["b"].
+
+ Args:
+ dic (dict): a nested dictionary
+ k (str): a single string meant to index several levels down into
+ the nested dictionary, where levels can be denoted by "/" or
+ by ".".
+ Returns:
+ val: the nested dictionary value for the provided key
+ """
+ val = dic
+ subkeys = re.split('/|\.', k)
+ for s in subkeys[:-1]:
+ val = val[s]
+ return val[subkeys[-1]]
+
+
+def set_value_for_key(dic, k, v):
+ """
+ Set value for hierarchical dictionary with levels denoted by "/" or ".".
+
+ Args:
+ dic (dict): a nested dictionary
+ k (str): a single string meant to index several levels down into
+ the nested dictionary, where levels can be denoted by "/" or
+ by ".".
+ v: the value to set at the provided key
+ """
+ val = dic
+ subkeys = re.split('/|\.', k) #k.split('/')
+ for s in subkeys[:-1]:
+ val = val[s]
+ val[subkeys[-1]] = v
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/log_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/log_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b431d5d4988dfdd60135eb3b81319fc825ba223a
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/log_utils.py
@@ -0,0 +1,230 @@
+"""
+This file contains utility classes and functions for logging to stdout, stderr,
+and to tensorboard.
+"""
+import os
+import sys
+import numpy as np
+from datetime import datetime
+from contextlib import contextmanager
+import textwrap
+import time
+from tqdm import tqdm
+from termcolor import colored
+
+import robomimic
+
+# global list of warning messages can be populated with @log_warning and flushed with @flush_warnings
+WARNINGS_BUFFER = []
+
+
+class PrintLogger(object):
+ """
+ This class redirects print statements to both console and a file.
+ """
+ def __init__(self, log_file):
+ self.terminal = sys.stdout
+ print('STDOUT will be forked to %s' % log_file)
+ self.log_file = open(log_file, "a")
+
+ def fileno(self):
+ return self.terminal.fileno()
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.log_file.write(message)
+ self.log_file.flush()
+
+ def flush(self):
+ # this flush method is needed for python 3 compatibility.
+ # this handles the flush command by doing nothing.
+ # you might want to specify some extra behavior here.
+ pass
+
+
+class DataLogger(object):
+ """
+ Logging class to log metrics to tensorboard and/or retrieve running statistics about logged data.
+ """
+ def __init__(self, log_dir, config, log_tb=True, log_wandb=False):
+ """
+ Args:
+ log_dir (str): base path to store logs
+ log_tb (bool): whether to use tensorboard logging
+ """
+ self._tb_logger = None
+ self._wandb_logger = None
+ self._data = dict() # store all the scalar data logged so far
+
+ if log_tb:
+ from tensorboardX import SummaryWriter
+ self._tb_logger = SummaryWriter(os.path.join(log_dir, 'tb'))
+
+ if log_wandb:
+ import wandb
+ import robomimic.macros as Macros
+
+ # set up wandb api key if specified in macros
+ if Macros.WANDB_API_KEY is not None:
+ os.environ["WANDB_API_KEY"] = Macros.WANDB_API_KEY
+
+ assert Macros.WANDB_ENTITY is not None, "WANDB_ENTITY macro is set to None." \
+ "\nSet this macro in {base_path}/macros_private.py" \
+ "\nIf this file does not exist, first run python {base_path}/scripts/setup_macros.py".format(base_path=robomimic.__path__[0])
+
+ # attempt to set up wandb 10 times. If unsuccessful after these trials, don't use wandb
+ num_attempts = 10
+ for attempt in range(num_attempts):
+ try:
+ # set up wandb
+ self._wandb_logger = wandb
+
+ self._wandb_logger.init(
+ entity=Macros.WANDB_ENTITY,
+ project=config.experiment.logging.wandb_proj_name,
+ name=config.experiment.name,
+ dir=log_dir,
+ mode=("offline" if attempt == num_attempts - 1 else "online"),
+ )
+
+ # set up info for identifying experiment
+ wandb_config = {k: v for (k, v) in config.meta.items() if k not in ["hp_keys", "hp_values"]}
+ for (k, v) in zip(config.meta["hp_keys"], config.meta["hp_values"]):
+ wandb_config[k] = v
+ if "algo" not in wandb_config:
+ wandb_config["algo"] = config.algo_name
+ self._wandb_logger.config.update(wandb_config)
+
+ break
+ except Exception as e:
+ log_warning("wandb initialization error (attempt #{}): {}".format(attempt + 1, e))
+ self._wandb_logger = None
+ time.sleep(30)
+
+ def record(self, k, v, epoch, data_type='scalar', log_stats=False):
+ """
+ Record data with logger.
+ Args:
+ k (str): key string
+ v (float or image): value to store
+ epoch: current epoch number
+ data_type (str): the type of data. either 'scalar' or 'image'
+ log_stats (bool): whether to store the mean/max/min/std for all data logged so far with key k
+ """
+
+ assert data_type in ['scalar', 'image']
+
+ if data_type == 'scalar':
+ # maybe update internal cache if logging stats for this key
+ if log_stats or k in self._data: # any key that we're logging or previously logged
+ if k not in self._data:
+ self._data[k] = []
+ self._data[k].append(v)
+
+ # maybe log to tensorboard
+ if self._tb_logger is not None:
+ if data_type == 'scalar':
+ self._tb_logger.add_scalar(k, v, epoch)
+ if log_stats:
+ stats = self.get_stats(k)
+ for (stat_k, stat_v) in stats.items():
+ stat_k_name = '{}-{}'.format(k, stat_k)
+ self._tb_logger.add_scalar(stat_k_name, stat_v, epoch)
+ elif data_type == 'image':
+ self._tb_logger.add_images(k, img_tensor=v, global_step=epoch, dataformats="NHWC")
+
+ if self._wandb_logger is not None:
+ try:
+ if data_type == 'scalar':
+ self._wandb_logger.log({k: v}, step=epoch)
+ if log_stats:
+ stats = self.get_stats(k)
+ for (stat_k, stat_v) in stats.items():
+ self._wandb_logger.log({"{}/{}".format(k, stat_k): stat_v}, step=epoch)
+ elif data_type == 'image':
+ raise NotImplementedError
+ except Exception as e:
+ log_warning("wandb logging: {}".format(e))
+
+ def get_stats(self, k):
+ """
+ Computes running statistics for a particular key.
+ Args:
+ k (str): key string
+ Returns:
+ stats (dict): dictionary of statistics
+ """
+ stats = dict()
+ stats['mean'] = np.mean(self._data[k])
+ stats['std'] = np.std(self._data[k])
+ stats['min'] = np.min(self._data[k])
+ stats['max'] = np.max(self._data[k])
+ return stats
+
+ def close(self):
+ """
+ Run before terminating to make sure all logs are flushed
+ """
+ if self._tb_logger is not None:
+ self._tb_logger.close()
+
+ if self._wandb_logger is not None:
+ self._wandb_logger.finish()
+
+
+class custom_tqdm(tqdm):
+ """
+ Small extension to tqdm to make a few changes from default behavior.
+ By default tqdm writes to stderr. Instead, we change it to write
+ to stdout.
+ """
+ def __init__(self, *args, **kwargs):
+ assert "file" not in kwargs
+ super(custom_tqdm, self).__init__(*args, file=sys.stdout, **kwargs)
+
+
+@contextmanager
+def silence_stdout():
+ """
+ This contextmanager will redirect stdout so that nothing is printed
+ to the terminal. Taken from the link below:
+
+ https://stackoverflow.com/questions/6735917/redirecting-stdout-to-nothing-in-python
+ """
+ old_target = sys.stdout
+ try:
+ with open(os.devnull, "w") as new_target:
+ sys.stdout = new_target
+ yield new_target
+ finally:
+ sys.stdout = old_target
+
+
+def log_warning(message, color="yellow", print_now=True):
+ """
+ This function logs a warning message by recording it in a global warning buffer.
+ The global registry will be maintained until @flush_warnings is called, at
+ which point the warnings will get printed to the terminal.
+
+ Args:
+ message (str): warning message to display
+ color (str): color of message - defaults to "yellow"
+ print_now (bool): if True (default), will print to terminal immediately, in
+ addition to adding it to the global warning buffer
+ """
+ global WARNINGS_BUFFER
+ buffer_message = colored("ROBOMIMIC WARNING(\n{}\n)".format(textwrap.indent(message, " ")), color)
+ WARNINGS_BUFFER.append(buffer_message)
+ if print_now:
+ print(buffer_message)
+
+
+def flush_warnings():
+ """
+ This function flushes all warnings from the global warning buffer to the terminal and
+ clears the global registry.
+ """
+ global WARNINGS_BUFFER
+ for msg in WARNINGS_BUFFER:
+ print(msg)
+ WARNINGS_BUFFER = []
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/loss_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/loss_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3f5bf223ed7dfbd510b4b8a8edf2e98b0567613
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/loss_utils.py
@@ -0,0 +1,208 @@
+"""
+This file contains a collection of useful loss functions for use with torch tensors.
+"""
+
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def cosine_loss(preds, labels):
+ """
+ Cosine loss between two tensors.
+
+ Args:
+ preds (torch.Tensor): torch tensor
+ labels (torch.Tensor): torch tensor
+
+ Returns:
+ loss (torch.Tensor): cosine loss
+ """
+ sim = torch.nn.CosineSimilarity(dim=len(preds.shape) - 1)(preds, labels)
+ return -torch.mean(sim - 1.0)
+
+
+def KLD_0_1_loss(mu, logvar):
+ """
+ KL divergence loss. Computes D_KL( N(mu, sigma) || N(0, 1) ). Note that
+ this function averages across the batch dimension, but sums across dimension.
+
+ Args:
+ mu (torch.Tensor): mean tensor of shape (B, D)
+ logvar (torch.Tensor): logvar tensor of shape (B, D)
+
+ Returns:
+ loss (torch.Tensor): KL divergence loss between the input gaussian distribution
+ and N(0, 1)
+ """
+ return -0.5 * (1. + logvar - mu.pow(2) - logvar.exp()).sum(dim=1).mean()
+
+
+def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2):
+ """
+ KL divergence loss between two Gaussian distributions. This function
+ computes the average loss across the batch.
+
+ Args:
+ mu_1 (torch.Tensor): first means tensor of shape (B, D)
+ logvar_1 (torch.Tensor): first logvars tensor of shape (B, D)
+ mu_2 (torch.Tensor): second means tensor of shape (B, D)
+ logvar_2 (torch.Tensor): second logvars tensor of shape (B, D)
+
+ Returns:
+ loss (torch.Tensor): KL divergence loss between the two gaussian distributions
+ """
+ return -0.5 * (1. + \
+ logvar_1 - logvar_2 \
+ - ((mu_2 - mu_1).pow(2) / logvar_2.exp()) \
+ - (logvar_1.exp() / logvar_2.exp()) \
+ ).sum(dim=1).mean()
+
+
+def log_normal(x, m, v):
+ """
+ Log probability of tensor x under diagonal multivariate normal with
+ mean m and variance v. The last dimension of the tensors is treated
+ as the dimension of the Gaussian distribution - all other dimensions
+ are treated as independent Gaussians. Adapted from CS 236 at Stanford.
+
+ Args:
+ x (torch.Tensor): tensor with shape (B, ..., D)
+ m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D)
+ v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D)
+
+ Returns:
+ log_prob (torch.Tensor): log probabilities of shape (B, ...)
+ """
+ element_wise = -0.5 * (torch.log(v) + (x - m).pow(2) / v + np.log(2 * np.pi))
+ log_prob = element_wise.sum(-1)
+ return log_prob
+
+
+def log_normal_mixture(x, m, v, w=None, log_w=None):
+ """
+ Log probability of tensor x under a uniform mixture of Gaussians.
+ Adapted from CS 236 at Stanford.
+
+ Args:
+ x (torch.Tensor): tensor with shape (B, D)
+ m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where
+ M is number of mixture components
+ v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where
+ M is number of mixture components
+ w (torch.Tensor): weights tensor - if provided, should be
+ shape (B, M) or (1, M)
+ log_w (torch.Tensor): log-weights tensor - if provided, should be
+ shape (B, M) or (1, M)
+
+ Returns:
+ log_prob (torch.Tensor): log probabilities of shape (B,)
+ """
+
+ # (B , D) -> (B , 1, D)
+ x = x.unsqueeze(1)
+ # (B, 1, D) -> (B, M, D) -> (B, M)
+ log_prob = log_normal(x, m, v)
+ if w is not None or log_w is not None:
+ # this weights the log probabilities by the mixture weights so we have log(w_i * N(x | m_i, v_i))
+ if w is not None:
+ assert log_w is None
+ log_w = torch.log(w)
+ log_prob += log_w
+ # then compute log sum_i exp [log(w_i * N(x | m_i, v_i))]
+ # (B, M) -> (B,)
+ log_prob = log_sum_exp(log_prob , dim=1)
+ else:
+ # (B, M) -> (B,)
+ log_prob = log_mean_exp(log_prob , dim=1) # mean accounts for uniform weights
+ return log_prob
+
+
+def log_mean_exp(x, dim):
+ """
+ Compute the log(mean(exp(x), dim)) in a numerically stable manner.
+ Adapted from CS 236 at Stanford.
+
+ Args:
+ x (torch.Tensor): a tensor
+ dim (int): dimension along which mean is computed
+
+ Returns:
+ y (torch.Tensor): log(mean(exp(x), dim))
+ """
+ return log_sum_exp(x, dim) - np.log(x.size(dim))
+
+
+def log_sum_exp(x, dim=0):
+ """
+ Compute the log(sum(exp(x), dim)) in a numerically stable manner.
+ Adapted from CS 236 at Stanford.
+
+ Args:
+ x (torch.Tensor): a tensor
+ dim (int): dimension along which sum is computed
+
+ Returns:
+ y (torch.Tensor): log(sum(exp(x), dim))
+ """
+ max_x = torch.max(x, dim)[0]
+ new_x = x - max_x.unsqueeze(dim).expand_as(x)
+ return max_x + (new_x.exp().sum(dim)).log()
+
+
+def project_values_onto_atoms(values, probabilities, atoms):
+ """
+ Project the categorical distribution given by @probabilities on the
+ grid of values given by @values onto a grid of values given by @atoms.
+ This is useful when computing a bellman backup where the backed up
+ values from the original grid will not be in the original support,
+ requiring L2 projection.
+
+ Each value in @values has a corresponding probability in @probabilities -
+ this probability mass is shifted to the closest neighboring grid points in
+ @atoms in proportion. For example, if the value in question is 0.2, and the
+ neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to
+ atom 0 and 0.2 of the probability weight will go to 1.
+
+ Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42
+
+ Args:
+ values: value grid to project, of shape (batch_size, n_atoms)
+ probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms)
+ atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms)
+
+ Returns:
+ new probability vectors that correspond to the L2 projection of the categorical distribution
+ onto @atoms
+ """
+
+ # make sure @atoms is shape (n_atoms,)
+ if len(atoms.shape) > 1:
+ atoms = atoms.squeeze(0)
+
+ # helper tensors from @atoms
+ vmin, vmax = atoms[0], atoms[1]
+ d_pos = torch.cat([atoms, vmin[None]], dim=0)[1:]
+ d_neg = torch.cat([vmax[None], atoms], dim=0)[:-1]
+
+ # ensure that @values grid is within the support of @atoms
+ clipped_values = values.clamp(min=vmin, max=vmax)[:, None, :] # (batch_size, 1, n_atoms)
+ clipped_atoms = atoms[None, :, None] # (1, n_atoms, 1)
+
+ # distance between atom values in support
+ d_pos = (d_pos - atoms)[None, :, None] # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1)
+ d_neg = (atoms - d_neg)[None, :, None] # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1)
+
+ # distances between all pairs of grid values
+ deltas = clipped_values - clipped_atoms # (batch_size, n_atoms, n_atoms)
+
+ # computes eqn (7) in distributional RL paper by doing the following - for each
+ # output atom in @atoms, consider values that are close enough, and weight their
+ # probability mass contribution by the normalized distance in [0, 1] given
+ # by (1. - (z_j - z_i) / (delta_z)).
+ d_sign = (deltas >= 0.).float()
+ delta_hat = (d_sign * deltas / d_pos) - ((1. - d_sign) * deltas / d_neg)
+ delta_hat = (1. - delta_hat).clamp(min=0., max=1.)
+ probabilities = probabilities[:, None, :]
+ return (delta_hat * probabilities).sum(dim=2)
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/obs_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/obs_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e1840de0419ca1523f066aad095481ea3516d95
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/obs_utils.py
@@ -0,0 +1,1025 @@
+"""
+A collection of utilities for working with observation dictionaries and
+different kinds of modalities such as images.
+"""
+import numpy as np
+from copy import deepcopy
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+
+import robomimic.utils.tensor_utils as TU
+
+# MACRO FOR VALID IMAGE CHANNEL SIZES
+VALID_IMAGE_CHANNEL_DIMS = {1, 3} # depth, rgb
+
+# DO NOT MODIFY THIS!
+# This keeps track of observation types (modalities) - and is populated on call to @initialize_obs_utils_with_obs_specs.
+# This will be a dictionary that maps observation modality (e.g. low_dim, rgb) to a list of observation
+# keys under that observation modality.
+OBS_MODALITIES_TO_KEYS = None
+
+# DO NOT MODIFY THIS!
+# This keeps track of observation types (modalities) - and is populated on call to @initialize_obs_utils_with_obs_specs.
+# This will be a dictionary that maps observation keys to their corresponding observation modality
+# (e.g. low_dim, rgb)
+OBS_KEYS_TO_MODALITIES = None
+
+# DO NOT MODIFY THIS
+# This holds the default encoder kwargs that will be used if none are passed at runtime for any given network
+DEFAULT_ENCODER_KWARGS = None
+
+# DO NOT MODIFY THIS
+# This holds the registered observation modality classes
+OBS_MODALITY_CLASSES = {}
+
+# DO NOT MODIFY THIS
+# This global dict stores mapping from observation encoder / randomizer network name to class.
+# We keep track of these registries to enable automated class inference at runtime, allowing
+# users to simply extend our base encoder / randomizer class and refer to that class in string form
+# in their config, without having to manually register their class internally.
+# This also future-proofs us for any additional encoder / randomizer classes we would
+# like to add ourselves.
+OBS_ENCODER_CORES = {"None": None} # Include default None
+OBS_RANDOMIZERS = {"None": None} # Include default None
+
+
+def register_obs_key(target_class):
+ assert target_class not in OBS_MODALITY_CLASSES, f"Already registered modality {target_class}!"
+ OBS_MODALITY_CLASSES[target_class.name] = target_class
+
+
+def register_encoder_core(target_class):
+ assert target_class not in OBS_ENCODER_CORES, f"Already registered obs encoder core {target_class}!"
+ OBS_ENCODER_CORES[target_class.__name__] = target_class
+
+
+def register_randomizer(target_class):
+ assert target_class not in OBS_RANDOMIZERS, f"Already registered obs randomizer {target_class}!"
+ OBS_RANDOMIZERS[target_class.__name__] = target_class
+
+
+class ObservationKeyToModalityDict(dict):
+ """
+ Custom dictionary class with the sole additional purpose of automatically registering new "keys" at runtime
+ without breaking. This is mainly for backwards compatibility, where certain keys such as "latent", "actions", etc.
+ are used automatically by certain models (e.g.: VAEs) but were never specified by the user externally in their
+ config. Thus, this dictionary will automatically handle those keys by implicitly associating them with the low_dim
+ modality.
+ """
+ def __getitem__(self, item):
+ # If a key doesn't already exist, warn the user and add default mapping
+ if item not in self.keys():
+ print(f"ObservationKeyToModalityDict: {item} not found,"
+ f" adding {item} to mapping with assumed low_dim modality!")
+ self.__setitem__(item, "low_dim")
+ return super(ObservationKeyToModalityDict, self).__getitem__(item)
+
+
+def obs_encoder_kwargs_from_config(obs_encoder_config):
+ """
+ Generate a set of args used to create visual backbones for networks
+ from the observation encoder config.
+
+ Args:
+ obs_encoder_config (Config): Config object containing relevant encoder information. Should be equivalent to
+ config.observation.encoder
+
+ Returns:
+ dict: Processed encoder kwargs
+ """
+ # Loop over each obs modality
+ # Unlock encoder config
+ obs_encoder_config.unlock()
+ for obs_modality, encoder_kwargs in obs_encoder_config.items():
+ # First run some sanity checks and store the classes
+ for cls_name, cores in zip(("core", "obs_randomizer"), (OBS_ENCODER_CORES, OBS_RANDOMIZERS)):
+ # Make sure the requested encoder for each obs_modality exists
+ cfg_cls = encoder_kwargs[f"{cls_name}_class"]
+ if cfg_cls is not None:
+ assert cfg_cls in cores, f"No {cls_name} class with name {cfg_cls} found, must register this class before" \
+ f"creating model!"
+ # encoder_kwargs[f"{cls_name}_class"] = cores[cfg_cls]
+
+ # Process core and randomizer kwargs
+ encoder_kwargs.core_kwargs = dict() if encoder_kwargs.core_kwargs is None else \
+ deepcopy(encoder_kwargs.core_kwargs)
+ encoder_kwargs.obs_randomizer_kwargs = dict() if encoder_kwargs.obs_randomizer_kwargs is None else \
+ deepcopy(encoder_kwargs.obs_randomizer_kwargs)
+
+ # Re-lock keys
+ obs_encoder_config.lock()
+
+ return dict(obs_encoder_config)
+
+
+def initialize_obs_modality_mapping_from_dict(modality_mapping):
+ """
+ This function is an alternative to @initialize_obs_utils_with_obs_specs, that allows manually setting of modalities.
+ NOTE: Only one of these should be called at runtime -- not both! (Note that all training scripts that use a config)
+ automatically handle obs modality mapping, so using this function is usually unnecessary)
+
+ Args:
+ modality_mapping (dict): Maps modality string names (e.g.: rgb, low_dim, etc.) to a list of observation
+ keys that should belong to that modality
+ """
+ global OBS_KEYS_TO_MODALITIES, OBS_MODALITIES_TO_KEYS
+
+ OBS_KEYS_TO_MODALITIES = ObservationKeyToModalityDict()
+ OBS_MODALITIES_TO_KEYS = dict()
+
+ for mod, keys in modality_mapping.items():
+ OBS_MODALITIES_TO_KEYS[mod] = deepcopy(keys)
+ OBS_KEYS_TO_MODALITIES.update({k: mod for k in keys})
+
+
+def initialize_obs_utils_with_obs_specs(obs_modality_specs):
+ """
+ This function should be called before using any observation key-specific
+ functions in this file, in order to make sure that all utility
+ functions are aware of the observation modalities (e.g. which ones
+ are low-dimensional, which ones are rgb, etc.).
+
+ It constructs two dictionaries: (1) that map observation modality (e.g. low_dim, rgb) to
+ a list of observation keys under that modality, and (2) that maps the inverse, specific
+ observation keys to their corresponding observation modality.
+
+ Input should be a nested dictionary (or list of such dicts) with the following structure:
+
+ obs_variant (str):
+ obs_modality (str): observation keys (list)
+ ...
+ ...
+
+ Example:
+ {
+ "obs": {
+ "low_dim": ["robot0_eef_pos", "robot0_eef_quat"],
+ "rgb": ["agentview_image", "robot0_eye_in_hand"],
+ }
+ "goal": {
+ "low_dim": ["robot0_eef_pos"],
+ "rgb": ["agentview_image"]
+ }
+ }
+
+ In the example, raw observations consist of low-dim and rgb modalities, with
+ the robot end effector pose under low-dim, and the agentview and wrist camera
+ images under rgb, while goal observations also consist of low-dim and rgb modalities,
+ with a subset of the raw observation keys per modality.
+
+ Args:
+ obs_modality_specs (dict or list): A nested dictionary (see docstring above for an example)
+ or a list of nested dictionaries. Accepting a list as input makes it convenient for
+ situations where multiple modules may each have their own modality spec.
+ """
+ global OBS_KEYS_TO_MODALITIES, OBS_MODALITIES_TO_KEYS
+
+ OBS_KEYS_TO_MODALITIES = ObservationKeyToModalityDict()
+
+ # accept one or more spec dictionaries - if it's just one, account for this
+ if isinstance(obs_modality_specs, dict):
+ obs_modality_spec_list = [obs_modality_specs]
+ else:
+ obs_modality_spec_list = obs_modality_specs
+
+ # iterates over observation specs
+ obs_modality_mapping = {}
+ for obs_modality_spec in obs_modality_spec_list:
+ # iterates over observation variants (e.g. observations, goals, subgoals)
+ for obs_modalities in obs_modality_spec.values():
+ for obs_modality, obs_keys in obs_modalities.items():
+ # add all keys for each obs modality to the corresponding list in obs_modality_mapping
+ if obs_modality not in obs_modality_mapping:
+ obs_modality_mapping[obs_modality] = []
+ obs_modality_mapping[obs_modality] += obs_keys
+ # loop over each modality, and add to global dict if it doesn't exist yet
+ for obs_key in obs_keys:
+ if obs_key not in OBS_KEYS_TO_MODALITIES:
+ OBS_KEYS_TO_MODALITIES[obs_key] = obs_modality
+ # otherwise, run sanity check to make sure we don't have conflicting, duplicate entries
+ else:
+ assert OBS_KEYS_TO_MODALITIES[obs_key] == obs_modality, \
+ f"Cannot register obs key {obs_key} with modality {obs_modality}; " \
+ f"already exists with corresponding modality {OBS_KEYS_TO_MODALITIES[obs_key]}"
+
+ # remove duplicate entries and store in global mapping
+ OBS_MODALITIES_TO_KEYS = { obs_modality : list(set(obs_modality_mapping[obs_modality])) for obs_modality in obs_modality_mapping }
+
+
+def initialize_default_obs_encoder(obs_encoder_config):
+ """
+ Initializes the default observation encoder kwarg information to be used by all networks if no values are manually
+ specified at runtime.
+
+ Args:
+ obs_encoder_config (Config): Observation encoder config to use.
+ Should be equivalent to config.observation.encoder
+ """
+ global DEFAULT_ENCODER_KWARGS
+ DEFAULT_ENCODER_KWARGS = obs_encoder_kwargs_from_config(obs_encoder_config)
+
+
+def initialize_obs_utils_with_config(config):
+ """
+ Utility function to parse config and call @initialize_obs_utils_with_obs_specs and
+ @initialize_default_obs_encoder_kwargs with the correct arguments.
+
+ Args:
+ config (BaseConfig instance): config object
+ """
+ if config.algo_name == "hbc":
+ obs_modality_specs = [
+ config.observation.planner.modalities,
+ config.observation.actor.modalities,
+ ]
+ obs_encoder_config = config.observation.actor.encoder
+ elif config.algo_name == "iris":
+ obs_modality_specs = [
+ config.observation.value_planner.planner.modalities,
+ config.observation.value_planner.value.modalities,
+ config.observation.actor.modalities,
+ ]
+ obs_encoder_config = config.observation.actor.encoder
+ else:
+ obs_modality_specs = [config.observation.modalities]
+ obs_encoder_config = config.observation.encoder
+ initialize_obs_utils_with_obs_specs(obs_modality_specs=obs_modality_specs)
+ initialize_default_obs_encoder(obs_encoder_config=obs_encoder_config)
+
+
+def key_is_obs_modality(key, obs_modality):
+ """
+ Check if observation key corresponds to modality @obs_modality.
+
+ Args:
+ key (str): obs key name to check
+ obs_modality (str): observation modality - e.g.: "low_dim", "rgb"
+ """
+ assert OBS_KEYS_TO_MODALITIES is not None, "error: must call ObsUtils.initialize_obs_utils_with_obs_config first"
+ return OBS_KEYS_TO_MODALITIES[key] == obs_modality
+
+
+def center_crop(im, t_h, t_w):
+ """
+ Takes a center crop of an image.
+
+ Args:
+ im (np.array or torch.Tensor): image of shape (..., height, width, channel)
+ t_h (int): height of crop
+ t_w (int): width of crop
+
+ Returns:
+ im (np.array or torch.Tensor): center cropped image
+ """
+ assert(im.shape[-3] >= t_h and im.shape[-2] >= t_w)
+ assert(im.shape[-1] in [1, 3])
+ crop_h = int((im.shape[-3] - t_h) / 2)
+ crop_w = int((im.shape[-2] - t_w) / 2)
+ return im[..., crop_h:crop_h + t_h, crop_w:crop_w + t_w, :]
+
+
+def batch_image_hwc_to_chw(im):
+ """
+ Channel swap for images - useful for preparing images for
+ torch training.
+
+ Args:
+ im (np.array or torch.Tensor): image of shape (batch, height, width, channel)
+ or (height, width, channel)
+
+ Returns:
+ im (np.array or torch.Tensor): image of shape (batch, channel, height, width)
+ or (channel, height, width)
+ """
+ start_dims = np.arange(len(im.shape) - 3).tolist()
+ s = start_dims[-1] if len(start_dims) > 0 else -1
+ if isinstance(im, np.ndarray):
+ return im.transpose(start_dims + [s + 3, s + 1, s + 2])
+ else:
+ return im.permute(start_dims + [s + 3, s + 1, s + 2])
+
+
+def batch_image_chw_to_hwc(im):
+ """
+ Inverse of channel swap in @batch_image_hwc_to_chw.
+
+ Args:
+ im (np.array or torch.Tensor): image of shape (batch, channel, height, width)
+ or (channel, height, width)
+
+ Returns:
+ im (np.array or torch.Tensor): image of shape (batch, height, width, channel)
+ or (height, width, channel)
+ """
+ start_dims = np.arange(len(im.shape) - 3).tolist()
+ s = start_dims[-1] if len(start_dims) > 0 else -1
+ if isinstance(im, np.ndarray):
+ return im.transpose(start_dims + [s + 2, s + 3, s + 1])
+ else:
+ return im.permute(start_dims + [s + 2, s + 3, s + 1])
+
+
+def process_obs(obs, obs_modality=None, obs_key=None):
+ """
+ Process observation @obs corresponding to @obs_modality modality (or implicitly inferred from @obs_key)
+ to prepare for network input.
+
+ Note that either obs_modality OR obs_key must be specified!
+
+ If both are specified, obs_key will override obs_modality
+
+ Args:
+ obs (np.array or torch.Tensor): Observation to process. Leading batch dimension is optional
+ obs_modality (str): Observation modality (e.g.: depth, image, low_dim, etc.)
+ obs_key (str): Name of observation from which to infer @obs_modality
+
+ Returns:
+ processed_obs (np.array or torch.Tensor): processed observation
+ """
+ assert obs_modality is not None or obs_key is not None, "Either obs_modality or obs_key must be specified!"
+ if obs_key is not None:
+ obs_modality = OBS_KEYS_TO_MODALITIES[obs_key]
+ return OBS_MODALITY_CLASSES[obs_modality].process_obs(obs)
+
+
+def process_obs_dict(obs_dict):
+ """
+ Process observations in observation dictionary to prepare for network input.
+
+ Args:
+ obs_dict (dict): dictionary mapping observation keys to np.array or
+ torch.Tensor. Leading batch dimensions are optional.
+
+ Returns:
+ new_dict (dict): dictionary where observation keys have been processed by their corresponding processors
+ """
+ return { k : process_obs(obs=obs, obs_key=k) for k, obs in obs_dict.items() } # shallow copy
+
+
+def process_frame(frame, channel_dim, scale):
+ """
+ Given frame fetched from dataset, process for network input. Converts array
+ to float (from uint8), normalizes pixels from range [0, @scale] to [0, 1], and channel swaps
+ from (H, W, C) to (C, H, W).
+
+ Args:
+ frame (np.array or torch.Tensor): frame array
+ channel_dim (int): Number of channels to sanity check for
+ scale (float or None): Value to normalize inputs by
+
+ Returns:
+ processed_frame (np.array or torch.Tensor): processed frame
+ """
+ # Channel size should either be 3 (RGB) or 1 (depth)
+ frame = TU.to_float(frame)
+ if scale is not None:
+ frame = frame / scale
+ frame = frame.clip(0.0, 1.0)
+ if frame.shape[-1] == 3 or frame.shape[-1] == 1:
+ frame = batch_image_hwc_to_chw(frame)
+
+ return frame
+
+
+def unprocess_obs(obs, obs_modality=None, obs_key=None):
+ """
+ Prepare observation @obs corresponding to @obs_modality modality (or implicitly inferred from @obs_key)
+ to prepare for deployment.
+
+ Note that either obs_modality OR obs_key must be specified!
+
+ If both are specified, obs_key will override obs_modality
+
+ Args:
+ obs (np.array or torch.Tensor): Observation to unprocess. Leading batch dimension is optional
+ obs_modality (str): Observation modality (e.g.: depth, image, low_dim, etc.)
+ obs_key (str): Name of observation from which to infer @obs_modality
+
+ Returns:
+ unprocessed_obs (np.array or torch.Tensor): unprocessed observation
+ """
+ assert obs_modality is not None or obs_key is not None, "Either obs_modality or obs_key must be specified!"
+ if obs_key is not None:
+ obs_modality = OBS_KEYS_TO_MODALITIES[obs_key]
+ return OBS_MODALITY_CLASSES[obs_modality].unprocess_obs(obs)
+
+
+def unprocess_obs_dict(obs_dict):
+ """
+ Prepare processed observation dictionary for saving to dataset. Inverse of
+ @process_obs.
+
+ Args:
+ obs_dict (dict): dictionary mapping observation keys to np.array or
+ torch.Tensor. Leading batch dimensions are optional.
+
+ Returns:
+ new_dict (dict): dictionary where observation keys have been unprocessed by
+ their respective unprocessor methods
+ """
+ return { k : unprocess_obs(obs=obs, obs_key=k) for k, obs in obs_dict.items() } # shallow copy
+
+
+def unprocess_frame(frame, channel_dim, scale):
+ """
+ Given frame prepared for network input, prepare for saving to dataset.
+ Inverse of @process_frame.
+
+ Args:
+ frame (np.array or torch.Tensor): frame array
+ channel_dim (int): What channel dimension should be (used for sanity check)
+ scale (float or None): Scaling factor to apply during denormalization
+
+ Returns:
+ unprocessed_frame (np.array or torch.Tensor): frame passed through
+ inverse operation of @process_frame
+ """
+ assert frame.shape[-3] == channel_dim # check for channel dimension
+ frame = batch_image_chw_to_hwc(frame)
+ if scale is not None:
+ frame = scale * frame
+ return frame
+
+
+def get_processed_shape(obs_modality, input_shape):
+ """
+ Given observation modality @obs_modality and expected inputs of shape @input_shape (excluding batch dimension), return the
+ expected processed observation shape resulting from process_{obs_modality}.
+
+ Args:
+ obs_modality (str): Observation modality to use (e.g.: low_dim, rgb, depth, etc...)
+ input_shape (list of int): Expected input dimensions, excluding the batch dimension
+
+ Returns:
+ list of int: expected processed input shape
+ """
+ return list(process_obs(obs=np.zeros(input_shape), obs_modality=obs_modality).shape)
+
+
+def normalize_dict(dict, normalization_stats):
+ """
+ Normalize dict using the provided "offset" and "scale" entries
+ for each observation key. The dictionary will be
+ modified in-place.
+
+ Args:
+ dict (dict): dictionary mapping key to np.array or
+ torch.Tensor. Leading batch dimensions are optional.
+
+ normalization_stats (dict): this should map keys to dicts
+ with a "offset" and "scale" of shape (1, ...) where ... is the default
+ shape for the dict value.
+
+ Returns:
+ dict (dict): obs dict with normalized arrays
+ """
+
+ # ensure we have statistics for each modality key in the dict
+ assert set(dict.keys()).issubset(normalization_stats)
+
+ for m in dict:
+ # get rid of extra dimension - we will pad for broadcasting later
+ offset = normalization_stats[m]["offset"][0]
+ scale = normalization_stats[m]["scale"][0]
+
+ # shape consistency checks
+ m_num_dims = len(offset.shape)
+ shape_len_diff = len(dict[m].shape) - m_num_dims
+ assert shape_len_diff >= 0, "shape length mismatch in @normalize_dict"
+ assert dict[m].shape[-m_num_dims:] == offset.shape, "shape mismatch in @normalize_dict"
+
+ # dict can have one or more leading batch dims - prepare for broadcasting
+ reshape_padding = tuple([1] * shape_len_diff)
+ offset = offset.reshape(reshape_padding + tuple(offset.shape))
+ scale = scale.reshape(reshape_padding + tuple(scale.shape))
+
+ dict[m] = (dict[m] - offset) / scale
+
+ return dict
+
+
+def unnormalize_dict(dict, normalization_stats):
+ """
+ Unnormalize dict using the provided "offset" and "scale" entries
+ for each observation key. The dictionary will be
+ modified in-place.
+
+ Args:
+ dict (dict): dictionary mapping key to np.array or
+ torch.Tensor. Leading batch dimensions are optional.
+
+ normalization_stats (dict): this should map keys to dicts
+ with a "offset" and "scale" of shape (1, ...) where ... is the default
+ shape for the dict value.
+
+ Returns:
+ dict (dict): obs dict with normalized arrays
+ """
+
+ # ensure we have statistics for each modality key in the dict
+ assert set(dict.keys()).issubset(normalization_stats)
+
+ for m in dict:
+ # get rid of extra dimension - we will pad for broadcasting later
+ offset = normalization_stats[m]["offset"][0]
+ scale = normalization_stats[m]["scale"][0]
+
+ # shape consistency checks
+ m_num_dims = len(offset.shape)
+ shape_len_diff = len(dict[m].shape) - m_num_dims
+ assert shape_len_diff >= 0, "shape length mismatch in @unnormalize_dict"
+ assert dict[m].shape[-m_num_dims:] == offset.shape, "shape mismatch in @unnormalize_dict"
+
+ # dict can have one or more leading batch dims - prepare for broadcasting
+ reshape_padding = tuple([1] * shape_len_diff)
+ offset = offset.reshape(reshape_padding + tuple(offset.shape))
+ scale = scale.reshape(reshape_padding + tuple(scale.shape))
+
+ dict[m] = (dict[m] * scale) + offset
+
+ return dict
+
+
+def has_modality(modality, obs_keys):
+ """
+ Returns True if @modality is present in the list of observation keys @obs_keys.
+
+ Args:
+ modality (str): modality to check for, e.g.: rgb, depth, etc.
+ obs_keys (list): list of observation keys
+ """
+ for k in obs_keys:
+ if key_is_obs_modality(k, obs_modality=modality):
+ return True
+ return False
+
+
+def repeat_and_stack_observation(obs_dict, n):
+ """
+ Given an observation dictionary and a desired repeat value @n,
+ this function will return a new observation dictionary where
+ each modality is repeated @n times and the copies are
+ stacked in the first dimension.
+
+ For example, if a batch of 3 observations comes in, and n is 2,
+ the output will look like [ob1; ob1; ob2; ob2; ob3; ob3] in
+ each modality.
+
+ Args:
+ obs_dict (dict): dictionary mapping observation key to np.array or
+ torch.Tensor. Leading batch dimensions are optional.
+
+ n (int): number to repeat by
+
+ Returns:
+ repeat_obs_dict (dict): repeated obs dict
+ """
+ return TU.repeat_by_expand_at(obs_dict, repeats=n, dim=0)
+
+
+def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
+ """
+ Crops images at the locations specified by @crop_indices. Crops will be
+ taken across all channels.
+
+ Args:
+ images (torch.Tensor): batch of images of shape [..., C, H, W]
+
+ crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
+ N is the number of crops to take per image and each entry corresponds
+ to the pixel height and width of where to take the crop. Note that
+ the indices can also be of shape [..., 2] if only 1 crop should
+ be taken per image. Leading dimensions must be consistent with
+ @images argument. Each index specifies the top left of the crop.
+ Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
+ H and W are the height and width of @images and CH and CW are
+ @crop_height and @crop_width.
+
+ crop_height (int): height of crop to take
+
+ crop_width (int): width of crop to take
+
+ Returns:
+ crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
+ """
+
+ # make sure length of input shapes is consistent
+ assert crop_indices.shape[-1] == 2
+ ndim_im_shape = len(images.shape)
+ ndim_indices_shape = len(crop_indices.shape)
+ assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
+
+ # maybe pad so that @crop_indices is shape [..., N, 2]
+ is_padded = False
+ if ndim_im_shape == ndim_indices_shape + 2:
+ crop_indices = crop_indices.unsqueeze(-2)
+ is_padded = True
+
+ # make sure leading dimensions between images and indices are consistent
+ assert images.shape[:-3] == crop_indices.shape[:-2]
+
+ device = images.device
+ image_c, image_h, image_w = images.shape[-3:]
+ num_crops = crop_indices.shape[-2]
+
+ # make sure @crop_indices are in valid range
+ assert (crop_indices[..., 0] >= 0).all().item()
+ assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
+ assert (crop_indices[..., 1] >= 0).all().item()
+ assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
+
+ # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
+
+ # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
+ crop_ind_grid_h = torch.arange(crop_height).to(device)
+ crop_ind_grid_h = TU.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
+ # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
+ crop_ind_grid_w = torch.arange(crop_width).to(device)
+ crop_ind_grid_w = TU.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
+ # combine into shape [CH, CW, 2]
+ crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
+
+ # Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
+ # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
+ # shape array that tells us which pixels from the corresponding source image to grab.
+ grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
+ all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
+
+ # For using @torch.gather, convert to flat indices from 2D indices, and also
+ # repeat across the channel dimension. To get flat index of each pixel to grab for
+ # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
+ all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
+ all_crop_inds = TU.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
+ all_crop_inds = TU.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
+
+ # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
+ images_to_crop = TU.unsqueeze_expand_at(images, size=num_crops, dim=-4)
+ images_to_crop = TU.flatten(images_to_crop, begin_axis=-2)
+ crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
+ # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
+ reshape_axis = len(crops.shape) - 1
+ crops = TU.reshape_dimensions(crops, begin_axis=reshape_axis, end_axis=reshape_axis,
+ target_dims=(crop_height, crop_width))
+
+ if is_padded:
+ # undo padding -> [..., C, CH, CW]
+ crops = crops.squeeze(-4)
+ return crops
+
+
+def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
+ """
+ For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
+ @images.
+
+ Args:
+ images (torch.Tensor): batch of images of shape [..., C, H, W]
+
+ crop_height (int): height of crop to take
+
+ crop_width (int): width of crop to take
+
+ num_crops (n): number of crops to sample
+
+ pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
+ encoding of the original source pixel locations. This means that the
+ output crops will contain information about where in the source image
+ it was sampled from.
+
+ Returns:
+ crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
+ if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
+
+ crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
+ """
+ device = images.device
+
+ # maybe add 2 channels of spatial encoding to the source image
+ source_im = images
+ if pos_enc:
+ # spatial encoding [y, x] in [0, 1]
+ h, w = source_im.shape[-2:]
+ pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+ pos_y = pos_y.float().to(device) / float(h)
+ pos_x = pos_x.float().to(device) / float(w)
+ position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
+
+ # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
+ leading_shape = source_im.shape[:-3]
+ position_enc = position_enc[(None,) * len(leading_shape)]
+ position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
+
+ # concat across channel dimension with input
+ source_im = torch.cat((source_im, position_enc), dim=-3)
+
+ # make sure sample boundaries ensure crops are fully within the images
+ image_c, image_h, image_w = source_im.shape[-3:]
+ max_sample_h = image_h - crop_height
+ max_sample_w = image_w - crop_width
+
+ # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
+ # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
+ # we will sample [B, N] indices, but this supports having more than one leading dimension,
+ # or possibly no leading dimension.
+ #
+ # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
+ crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
+ crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
+ crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
+
+ crops = crop_image_from_indices(
+ images=source_im,
+ crop_indices=crop_inds,
+ crop_height=crop_height,
+ crop_width=crop_width,
+ )
+
+ return crops, crop_inds
+
+
+class Modality:
+ """
+ Observation Modality class to encapsulate necessary functions needed to
+ process observations of this modality
+ """
+ # observation keys to associate with this modality
+ keys = set()
+
+ # Custom processing function that should prepare raw observations of this modality for training
+ _custom_obs_processor = None
+
+ # Custom unprocessing function that should prepare observations of this modality used during training for deployment
+ _custom_obs_unprocessor = None
+
+ # Name of this modality -- must be set by subclass!
+ name = None
+
+ def __init_subclass__(cls, **kwargs):
+ """
+ Hook method to automatically register all valid subclasses so we can keep track of valid modalities
+ """
+ assert cls.name is not None, f"Name of modality {cls.__name__} must be specified!"
+ register_obs_key(cls)
+
+ @classmethod
+ def set_keys(cls, keys):
+ """
+ Sets the observation keys associated with this modality.
+
+ Args:
+ keys (list or set): observation keys to associate with this modality
+ """
+ cls.keys = {k for k in keys}
+
+ @classmethod
+ def add_keys(cls, keys):
+ """
+ Adds the observation @keys associated with this modality to the current set of keys.
+
+ Args:
+ keys (list or set): observation keys to add to associate with this modality
+ """
+ for key in keys:
+ cls.keys.add(key)
+
+ @classmethod
+ def set_obs_processor(cls, processor=None):
+ """
+ Sets the processor for this observation modality. If @processor is set to None, then
+ the obs processor will use the default one (self.process_obs(...)). Otherwise, @processor
+ should be a function to process this corresponding observation modality.
+
+ Args:
+ processor (function or None): If not None, should be function that takes in either a
+ np.array or torch.Tensor and output the processed array / tensor. If None, will reset
+ to the default processor (self.process_obs(...))
+ """
+ cls._custom_obs_processor = processor
+
+ @classmethod
+ def set_obs_unprocessor(cls, unprocessor=None):
+ """
+ Sets the unprocessor for this observation modality. If @unprocessor is set to None, then
+ the obs unprocessor will use the default one (self.unprocess_obs(...)). Otherwise, @unprocessor
+ should be a function to process this corresponding observation modality.
+
+ Args:
+ unprocessor (function or None): If not None, should be function that takes in either a
+ np.array or torch.Tensor and output the unprocessed array / tensor. If None, will reset
+ to the default unprocessor (self.unprocess_obs(...))
+ """
+ cls._custom_obs_unprocessor = unprocessor
+
+ @classmethod
+ def _default_obs_processor(cls, obs):
+ """
+ Default processing function for this obs modality.
+
+ Note that this function is overridden by self.custom_obs_processor (a function with identical inputs / outputs)
+ if it is not None.
+
+ Args:
+ obs (np.array or torch.Tensor): raw observation, which may include a leading batch dimension
+
+ Returns:
+ np.array or torch.Tensor: processed observation
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def _default_obs_unprocessor(cls, obs):
+ """
+ Default unprocessing function for this obs modality.
+
+ Note that this function is overridden by self.custom_obs_unprocessor
+ (a function with identical inputs / outputs) if it is not None.
+
+ Args:
+ obs (np.array or torch.Tensor): processed observation, which may include a leading batch dimension
+
+ Returns:
+ np.array or torch.Tensor: unprocessed observation
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def process_obs(cls, obs):
+ """
+ Prepares an observation @obs of this modality for network input.
+
+ Args:
+ obs (np.array or torch.Tensor): raw observation, which may include a leading batch dimension
+
+ Returns:
+ np.array or torch.Tensor: processed observation
+ """
+ processor = cls._custom_obs_processor if \
+ cls._custom_obs_processor is not None else cls._default_obs_processor
+ return processor(obs)
+
+ @classmethod
+ def unprocess_obs(cls, obs):
+ """
+ Prepares an observation @obs of this modality for deployment.
+
+ Args:
+ obs (np.array or torch.Tensor): processed observation, which may include a leading batch dimension
+
+ Returns:
+ np.array or torch.Tensor: unprocessed observation
+ """
+ unprocessor = cls._custom_obs_unprocessor if \
+ cls._custom_obs_unprocessor is not None else cls._default_obs_unprocessor
+ return unprocessor(obs)
+
+ @classmethod
+ def process_obs_from_dict(cls, obs_dict, inplace=True):
+ """
+ Receives a dictionary of keyword mapped observations @obs_dict, and processes the observations with keys
+ corresponding to this modality. A copy will be made of the received dictionary unless @inplace is True
+
+ Args:
+ obs_dict (dict): Dictionary mapping observation keys to observations
+ inplace (bool): If True, will modify @obs_dict in place, otherwise, will create a copy
+
+ Returns:
+ dict: observation dictionary with processed observations corresponding to this modality
+ """
+ if inplace:
+ obs_dict = deepcopy(obs_dict)
+ # Loop over all keys and process the ones corresponding to this modality
+ for key, obs in obs_dict.values():
+ if key in cls.keys:
+ obs_dict[key] = cls.process_obs(obs)
+
+ return obs_dict
+
+
+class ImageModality(Modality):
+ """
+ Modality for RGB image observations
+ """
+ name = "rgb"
+
+ @classmethod
+ def _default_obs_processor(cls, obs):
+ """
+ Given image fetched from dataset, process for network input. Converts array
+ to float (from uint8), normalizes pixels from range [0, 255] to [0, 1], and channel swaps
+ from (H, W, C) to (C, H, W).
+
+ Args:
+ obs (np.array or torch.Tensor): image array
+
+ Returns:
+ processed_obs (np.array or torch.Tensor): processed image
+ """
+ return process_frame(frame=obs, channel_dim=3, scale=255.)
+
+ @classmethod
+ def _default_obs_unprocessor(cls, obs):
+ """
+ Given image prepared for network input, prepare for saving to dataset.
+ Inverse of @process_frame.
+
+ Args:
+ obs (np.array or torch.Tensor): image array
+
+ Returns:
+ unprocessed_obs (np.array or torch.Tensor): image passed through
+ inverse operation of @process_frame
+ """
+ return TU.to_uint8(unprocess_frame(frame=obs, channel_dim=3, scale=255.))
+
+
+class DepthModality(Modality):
+ """
+ Modality for depth observations
+ """
+ name = "depth"
+
+ @classmethod
+ def _default_obs_processor(cls, obs):
+ """
+ Given depth fetched from dataset, process for network input. Converts array
+ to float (from uint8), normalizes pixels from range [0, 1] to [0, 1], and channel swaps
+ from (H, W, C) to (C, H, W).
+
+ Args:
+ obs (np.array or torch.Tensor): depth array
+
+ Returns:
+ processed_obs (np.array or torch.Tensor): processed depth
+ """
+ return process_frame(frame=obs, channel_dim=1, scale=1.)
+
+ @classmethod
+ def _default_obs_unprocessor(cls, obs):
+ """
+ Given depth prepared for network input, prepare for saving to dataset.
+ Inverse of @process_depth.
+
+ Args:
+ obs (np.array or torch.Tensor): depth array
+
+ Returns:
+ unprocessed_obs (np.array or torch.Tensor): depth passed through
+ inverse operation of @process_depth
+ """
+ return unprocess_frame(frame=obs, channel_dim=1, scale=1.)
+
+
+class ScanModality(Modality):
+ """
+ Modality for scan observations
+ """
+ name = "scan"
+
+ @classmethod
+ def _default_obs_processor(cls, obs):
+ # Channel swaps ([...,] L, C) --> ([...,] C, L)
+
+ # First, add extra dimension at 2nd to last index to treat this as a frame
+ shape = obs.shape
+ new_shape = [*shape[:-2], 1, *shape[-2:]]
+ obs = obs.reshape(new_shape)
+
+ # Convert shape
+ obs = batch_image_hwc_to_chw(obs)
+
+ # Remove extra dimension (it's the second from last dimension)
+ obs = obs.squeeze(-2)
+ return obs
+
+ @classmethod
+ def _default_obs_unprocessor(cls, obs):
+ # Channel swaps ([B,] C, L) --> ([B,] L, C)
+
+ # First, add extra dimension at 1st index to treat this as a frame
+ shape = obs.shape
+ new_shape = [*shape[:-2], 1, *shape[-2:]]
+ obs = obs.reshape(new_shape)
+
+ # Convert shape
+ obs = batch_image_chw_to_hwc(obs)
+
+ # Remove extra dimension (it's the second from last dimension)
+ obs = obs.squeeze(-2)
+ return obs
+
+
+class LowDimModality(Modality):
+ """
+ Modality for low dimensional observations
+ """
+ name = "low_dim"
+
+ @classmethod
+ def _default_obs_processor(cls, obs):
+ return obs
+
+ @classmethod
+ def _default_obs_unprocessor(cls, obs):
+ return obs
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/python_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/python_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc71bd1aaaf08bb406f3a72e83886d86c0d19a6
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/python_utils.py
@@ -0,0 +1,73 @@
+"""
+Set of general purpose utility functions for easier interfacing with Python API
+"""
+import inspect
+from copy import deepcopy
+import robomimic.macros as Macros
+
+
+def get_class_init_kwargs(cls):
+ """
+ Helper function to return a list of all valid keyword arguments (excluding "self") for the given @cls class.
+
+ Args:
+ cls (object): Class from which to grab __init__ kwargs
+
+ Returns:
+ list: All keyword arguments (excluding "self") specified by @cls __init__ constructor method
+ """
+ return list(inspect.signature(cls.__init__).parameters.keys())[1:]
+
+
+def extract_subset_dict(dic, keys, copy=False):
+ """
+ Helper function to extract a subset of dictionary key-values from a current dictionary. Optionally (deep)copies
+ the values extracted from the original @dic if @copy is True.
+
+ Args:
+ dic (dict): Dictionary containing multiple key-values
+ keys (Iterable): Specific keys to extract from @dic. If the key doesn't exist in @dic, then the key is skipped
+ copy (bool): If True, will deepcopy all values corresponding to the specified @keys
+
+ Returns:
+ dict: Extracted subset dictionary containing only the specified @keys and their corresponding values
+ """
+ subset = {k: dic[k] for k in keys if k in dic}
+ return deepcopy(subset) if copy else subset
+
+
+def extract_class_init_kwargs_from_dict(cls, dic, copy=False, verbose=False):
+ """
+ Helper function to return a dictionary of key-values that specifically correspond to @cls class's __init__
+ constructor method, from @dic which may or may not contain additional, irrelevant kwargs.
+
+ Note that @dic may possibly be missing certain kwargs as specified by cls.__init__. No error will be raised.
+
+ Args:
+ cls (object): Class from which to grab __init__ kwargs that will be be used as filtering keys for @dic
+ dic (dict): Dictionary containing multiple key-values
+ copy (bool): If True, will deepcopy all values corresponding to the specified @keys
+ verbose (bool): If True (or if macro DEBUG is True), then will print out mismatched keys
+
+ Returns:
+ dict: Extracted subset dictionary possibly containing only the specified keys from cls.__init__ and their
+ corresponding values
+ """
+ # extract only relevant kwargs for this specific backbone
+ cls_keys = get_class_init_kwargs(cls)
+ subdic = extract_subset_dict(
+ dic=dic,
+ keys=cls_keys,
+ copy=copy,
+ )
+
+ # Run sanity check if verbose or debugging
+ if verbose or Macros.DEBUG:
+ keys_not_in_cls = [k for k in dic if k not in cls_keys]
+ keys_not_in_dic = [k for k in cls_keys if k not in list(dic.keys())]
+ if len(keys_not_in_cls) > 0:
+ print(f"Warning: For class {cls.__name__}, got unknown keys: {keys_not_in_cls} ")
+ if len(keys_not_in_dic) > 0:
+ print(f"Warning: For class {cls.__name__}, got missing keys: {keys_not_in_dic} ")
+
+ return subdic
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/script_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/script_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42eaf627715b7d8ce30aaf43dfbfd992678beab
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/script_utils.py
@@ -0,0 +1,15 @@
+"""
+Collection of miscellaneous utility tools
+"""
+
+def deep_update(d, u):
+ """
+ Copied from https://stackoverflow.com/a/3233356
+ """
+ import collections
+ for k, v in u.items():
+ if isinstance(v, collections.abc.Mapping):
+ d[k] = deep_update(d.get(k, {}), v)
+ else:
+ d[k] = v
+ return d
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/tensor_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfdbbc5299e6f623c40c1e9b2b40bb7a2ac5dcb4
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/tensor_utils.py
@@ -0,0 +1,995 @@
+"""
+A collection of utilities for working with nested tensor structures consisting
+of numpy arrays and torch tensors.
+"""
+import collections
+import numpy as np
+import torch
+
+
+def recursive_dict_list_tuple_apply(x, type_func_dict, error_on_missing_type=True):
+ """
+ Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
+ {data_type: function_to_apply}.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ type_func_dict (dict): a mapping from data types to the functions to be
+ applied for each data type.
+ error_on_missing_type (bool): if True, raise an error if a type outside the @type_func_dict is
+ encountered, else, just return the same value (identity function)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ assert(list not in type_func_dict)
+ assert(tuple not in type_func_dict)
+ assert(dict not in type_func_dict)
+
+ if isinstance(x, (dict, collections.OrderedDict)):
+ new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else dict()
+ for k, v in x.items():
+ new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict, error_on_missing_type)
+ return new_x
+ elif isinstance(x, (list, tuple)):
+ ret = [recursive_dict_list_tuple_apply(v, type_func_dict, error_on_missing_type) for v in x]
+ if isinstance(x, tuple):
+ ret = tuple(ret)
+ return ret
+ else:
+ for t, f in type_func_dict.items():
+ if isinstance(x, t):
+ return f(x)
+ else:
+ if error_on_missing_type:
+ raise NotImplementedError(
+ 'Cannot handle data type %s' % str(type(x)))
+ return x
+
+
+def map_tensor(x, func, error_on_missing_type=True):
+ """
+ Apply function @func to torch.Tensor objects in a nested dictionary or
+ list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ func (function): function to apply to each tensor
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: func,
+ type(None): lambda x: x,
+ },
+ error_on_missing_type=error_on_missing_type,
+ )
+
+
+def map_ndarray(x, func, error_on_missing_type=True):
+ """
+ Apply function @func to np.ndarray objects in a nested dictionary or
+ list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ func (function): function to apply to each array
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ np.ndarray: func,
+ type(None): lambda x: x,
+ },
+ error_on_missing_type=error_on_missing_type,
+ )
+
+
+def map_tensor_ndarray(x, tensor_func, ndarray_func, error_on_missing_type=True):
+ """
+ Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
+ np.ndarray objects in a nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ tensor_func (function): function to apply to each tensor
+ ndarray_Func (function): function to apply to each array
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: tensor_func,
+ np.ndarray: ndarray_func,
+ type(None): lambda x: x,
+ },
+ error_on_missing_type=error_on_missing_type,
+ )
+
+
+def clone(x):
+ """
+ Clones all torch tensors and numpy arrays in nested dictionary or list
+ or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.clone(),
+ np.ndarray: lambda x: x.copy(),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def detach(x):
+ """
+ Detaches all torch tensors in nested dictionary or list
+ or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.detach(),
+ }
+ )
+
+
+def to_batch(x):
+ """
+ Introduces a leading batch dimension of 1 for all torch tensors and numpy
+ arrays in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x[None, ...],
+ np.ndarray: lambda x: x[None, ...],
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_sequence(x):
+ """
+ Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
+ arrays in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x[:, None, ...],
+ np.ndarray: lambda x: x[:, None, ...],
+ type(None): lambda x: x,
+ }
+ )
+
+
+def index_at_time(x, ind):
+ """
+ Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
+ nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ ind (int): index
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x[:, ind, ...],
+ np.ndarray: lambda x: x[:, ind, ...],
+ type(None): lambda x: x,
+ }
+ )
+
+
+def unsqueeze(x, dim):
+ """
+ Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
+ in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ dim (int): dimension
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.unsqueeze(dim=dim),
+ np.ndarray: lambda x: np.expand_dims(x, axis=dim),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def contiguous(x):
+ """
+ Makes all torch tensors and numpy arrays contiguous in nested dictionary or
+ list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.contiguous(),
+ np.ndarray: lambda x: np.ascontiguousarray(x),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_device(x, device):
+ """
+ Sends all torch tensors in nested dictionary or list or tuple to device
+ @device, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ device (torch.Device): device to send tensors to
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, d=device: x.to(d),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_tensor(x):
+ """
+ Converts all numpy arrays in nested dictionary or list or tuple to
+ torch tensors (and leaves existing torch Tensors as-is), and returns
+ a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x,
+ np.ndarray: lambda x: torch.from_numpy(x),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_numpy(x):
+ """
+ Converts all torch tensors in nested dictionary or list or tuple to
+ numpy (and leaves existing numpy arrays as-is), and returns
+ a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ def f(tensor):
+ if tensor.is_cuda:
+ return tensor.detach().cpu().numpy()
+ else:
+ return tensor.detach().numpy()
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: f,
+ np.ndarray: lambda x: x,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_list(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to a list, and returns a new nested structure. Useful for
+ json encoding.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ def f(tensor):
+ if tensor.is_cuda:
+ return tensor.detach().cpu().numpy().tolist()
+ else:
+ return tensor.detach().numpy().tolist()
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: f,
+ np.ndarray: lambda x: x.tolist(),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_float(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to float type entries, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.float(),
+ np.ndarray: lambda x: x.astype(np.float32),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_uint8(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to uint8 type entries, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.byte(),
+ np.ndarray: lambda x: x.astype(np.uint8),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_uint16(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to uint16 type entries, and returns a new nested structure. Note
+ that torch does not support uint16, so int32 will be used (double storage).
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.int(),
+ np.ndarray: lambda x: x.astype(np.uint16),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def to_torch(x, device):
+ """
+ Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
+ torch tensors on device @device and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ device (torch.Device): device to send tensors to
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return to_device(to_float(to_tensor(x)), device)
+
+
+def to_one_hot_single(tensor, num_class):
+ """
+ Convert tensor to one-hot representation, assuming a certain number of total class labels.
+
+ Args:
+ tensor (torch.Tensor): tensor containing integer labels
+ num_class (int): number of classes
+
+ Returns:
+ x (torch.Tensor): tensor containing one-hot representation of labels
+ """
+ x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
+ x.scatter_(-1, tensor.unsqueeze(-1), 1)
+ return x
+
+
+def to_one_hot(tensor, num_class):
+ """
+ Convert all tensors in nested dictionary or list or tuple to one-hot representation,
+ assuming a certain number of total class labels.
+
+ Args:
+ tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
+ num_class (int): number of classes
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
+
+
+def flatten_single(x, begin_axis=1):
+ """
+ Flatten a tensor in all dimensions from @begin_axis onwards.
+
+ Args:
+ x (torch.Tensor): tensor to flatten
+ begin_axis (int): which axis to flatten from
+
+ Returns:
+ y (torch.Tensor): flattened tensor
+ """
+ fixed_size = x.size()[:begin_axis]
+ _s = list(fixed_size) + [-1]
+ return x.reshape(*_s)
+
+
+def flatten(x, begin_axis=1):
+ """
+ Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): which axis to flatten from
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
+ }
+ )
+
+
+def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
+ """
+ Reshape selected dimensions in a tensor to a target dimension.
+
+ Args:
+ x (torch.Tensor): tensor to reshape
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension (inclusive)
+ target_dims (tuple or list): target shape for the range of dimensions
+ (@begin_axis, @end_axis)
+
+ Returns:
+ y (torch.Tensor): reshaped tensor
+ """
+ assert(begin_axis <= end_axis)
+ assert(begin_axis >= 0)
+ assert(end_axis < len(x.shape))
+ assert(isinstance(target_dims, (tuple, list)))
+ s = x.shape
+ final_s = []
+ for i in range(len(s)):
+ if i == begin_axis:
+ final_s.extend(target_dims)
+ elif i < begin_axis or i > end_axis:
+ final_s.append(s[i])
+ return x.reshape(*final_s)
+
+
+def reshape_dimensions(x, begin_axis, end_axis, target_dims):
+ """
+ Reshape selected dimensions for all tensors in nested dictionary or list or tuple
+ to a target dimension.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension (inclusive)
+ target_dims (tuple or list): target shape for the range of dimensions
+ (@begin_axis, @end_axis)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=t),
+ np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=t),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def join_dimensions(x, begin_axis, end_axis):
+ """
+ Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
+ all tensors in nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=[-1]),
+ np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=[-1]),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def expand_at_single(x, size, dim):
+ """
+ Expand a tensor at a single dimension @dim by @size
+
+ Args:
+ x (torch.Tensor): input tensor
+ size (int): size to expand
+ dim (int): dimension to expand
+
+ Returns:
+ y (torch.Tensor): expanded tensor
+ """
+ assert dim < x.ndimension()
+ assert x.shape[dim] == 1
+ expand_dims = [-1] * x.ndimension()
+ expand_dims[dim] = size
+ return x.expand(*expand_dims)
+
+
+def expand_at(x, size, dim):
+ """
+ Expand all tensors in nested dictionary or list or tuple at a single
+ dimension @dim by @size.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size to expand
+ dim (int): dimension to expand
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
+
+
+def unsqueeze_expand_at(x, size, dim):
+ """
+ Unsqueeze and expand a tensor at a dimension @dim by @size.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size to expand
+ dim (int): dimension to unsqueeze and expand
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ x = unsqueeze(x, dim)
+ return expand_at(x, size, dim)
+
+
+def repeat_by_expand_at(x, repeats, dim):
+ """
+ Repeat a dimension by combining expand and reshape operations.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ repeats (int): number of times to repeat the target dimension
+ dim (int): dimension to repeat on
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ x = unsqueeze_expand_at(x, repeats, dim + 1)
+ return join_dimensions(x, dim, dim + 1)
+
+
+def named_reduce_single(x, reduction, dim):
+ """
+ Reduce tensor at a dimension by named reduction functions.
+
+ Args:
+ x (torch.Tensor): tensor to be reduced
+ reduction (str): one of ["sum", "max", "mean", "flatten"]
+ dim (int): dimension to be reduced (or begin axis for flatten)
+
+ Returns:
+ y (torch.Tensor): reduced tensor
+ """
+ assert x.ndimension() > dim
+ assert reduction in ["sum", "max", "mean", "flatten"]
+ if reduction == "flatten":
+ x = flatten(x, begin_axis=dim)
+ elif reduction == "max":
+ x = torch.max(x, dim=dim)[0] # [B, D]
+ elif reduction == "sum":
+ x = torch.sum(x, dim=dim)
+ else:
+ x = torch.mean(x, dim=dim)
+ return x
+
+
+def named_reduce(x, reduction, dim):
+ """
+ Reduces all tensors in nested dictionary or list or tuple at a dimension
+ using a named reduction function.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ reduction (str): one of ["sum", "max", "mean", "flatten"]
+ dim (int): dimension to be reduced (or begin axis for flatten)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
+
+
+def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
+ """
+ This function indexes out a target dimension of a tensor in a structured way,
+ by allowing a different value to be selected for each member of a flat index
+ tensor (@indices) corresponding to a source dimension. This can be interpreted
+ as moving along the source dimension, using the corresponding index value
+ in @indices to select values for all other dimensions outside of the
+ source and target dimensions. A common use case is to gather values
+ in target dimension 1 for each batch member (target dimension 0).
+
+ Args:
+ x (torch.Tensor): tensor to gather values for
+ target_dim (int): dimension to gather values along
+ source_dim (int): dimension to hold constant and use for gathering values
+ from the other dimensions
+ indices (torch.Tensor): flat index tensor with same shape as tensor @x along
+ @source_dim
+
+ Returns:
+ y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
+ """
+ assert len(indices.shape) == 1
+ assert x.shape[source_dim] == indices.shape[0]
+
+ # unsqueeze in all dimensions except the source dimension
+ new_shape = [1] * x.ndimension()
+ new_shape[source_dim] = -1
+ indices = indices.reshape(*new_shape)
+
+ # repeat in all dimensions - but preserve shape of source dimension,
+ # and make sure target_dimension has singleton dimension
+ expand_shape = list(x.shape)
+ expand_shape[source_dim] = -1
+ expand_shape[target_dim] = 1
+ indices = indices.expand(*expand_shape)
+
+ out = x.gather(dim=target_dim, index=indices)
+ return out.squeeze(target_dim)
+
+
+def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
+ """
+ Apply @gather_along_dim_with_dim_single to all tensors in a nested
+ dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ target_dim (int): dimension to gather values along
+ source_dim (int): dimension to hold constant and use for gathering values
+ from the other dimensions
+ indices (torch.Tensor): flat index tensor with same shape as tensor @x along
+ @source_dim
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x,
+ lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i))
+
+
+def gather_sequence_single(seq, indices):
+ """
+ Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
+ the batch given an index for each sequence.
+
+ Args:
+ seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
+ indices (torch.Tensor): tensor indices of shape [B]
+
+ Return:
+ y (torch.Tensor): indexed tensor of shape [B, ....]
+ """
+ return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
+
+
+def gather_sequence(seq, indices):
+ """
+ Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
+ for tensors with leading dimensions [B, T, ...].
+
+ Args:
+ seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ indices (torch.Tensor): tensor indices of shape [B]
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
+ """
+ return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
+
+
+def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
+ """
+ Pad input tensor or array @seq in the time dimension (dimension 1).
+
+ Args:
+ seq (np.ndarray or torch.Tensor): sequence to be padded
+ padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
+ batched (bool): if sequence has the batch dimension
+ pad_same (bool): if pad by duplicating
+ pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
+
+ Returns:
+ padded sequence (np.ndarray or torch.Tensor)
+ """
+ assert isinstance(seq, (np.ndarray, torch.Tensor))
+ assert pad_same or (pad_values is not None)
+ if pad_values is not None:
+ assert isinstance(pad_values, float)
+ repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
+ concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
+ ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
+ seq_dim = 1 if batched else 0
+
+ begin_pad = []
+ end_pad = []
+
+ if padding[0] > 0:
+ if batched:
+ pad = seq[:, [0]] if pad_same else ones_like_func(seq[:, [0]]) * pad_values
+ else:
+ pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
+ begin_pad.append(repeat_func(pad, padding[0], seq_dim))
+ if padding[1] > 0:
+ if batched:
+ pad = seq[:, [-1]] if pad_same else ones_like_func(seq[:, [-1]]) * pad_values
+ else:
+ pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
+ end_pad.append(repeat_func(pad, padding[1], seq_dim))
+
+ return concat_func(begin_pad + [seq] + end_pad, seq_dim)
+
+
+def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
+ """
+ Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
+
+ Args:
+ seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
+ batched (bool): if sequence has the batch dimension
+ pad_same (bool): if pad by duplicating
+ pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
+
+ Returns:
+ padded sequence (dict or list or tuple)
+ """
+ return recursive_dict_list_tuple_apply(
+ seq,
+ {
+ torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
+ pad_sequence_single(x, p, b, ps, pv),
+ np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
+ pad_sequence_single(x, p, b, ps, pv),
+ type(None): lambda x: x,
+ }
+ )
+
+
+def assert_size_at_dim_single(x, size, dim, msg):
+ """
+ Ensure that array or tensor @x has size @size in dim @dim.
+
+ Args:
+ x (np.ndarray or torch.Tensor): input array or tensor
+ size (int): size that tensors should have at @dim
+ dim (int): dimension to check
+ msg (str): text to display if assertion fails
+ """
+ assert x.shape[dim] == size, msg
+
+
+def assert_size_at_dim(x, size, dim, msg):
+ """
+ Ensure that arrays and tensors in nested dictionary or list or tuple have
+ size @size in dim @dim.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size that tensors should have at @dim
+ dim (int): dimension to check
+ """
+ map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
+
+
+def get_shape(x):
+ """
+ Get all shapes of arrays and tensors in nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple that contains each array or
+ tensor's shape
+ """
+ return recursive_dict_list_tuple_apply(
+ x,
+ {
+ torch.Tensor: lambda x: x.shape,
+ np.ndarray: lambda x: x.shape,
+ type(None): lambda x: x,
+ }
+ )
+
+
+def list_of_flat_dict_to_dict_of_list(list_of_dict):
+ """
+ Helper function to go from a list of flat dictionaries to a dictionary of lists.
+ By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
+ floats, etc.
+
+ Args:
+ list_of_dict (list): list of flat dictionaries
+
+ Returns:
+ dict_of_list (dict): dictionary of lists
+ """
+ assert isinstance(list_of_dict, list)
+ dic = collections.OrderedDict()
+ for i in range(len(list_of_dict)):
+ for k in list_of_dict[i]:
+ if k not in dic:
+ dic[k] = []
+ dic[k].append(list_of_dict[i][k])
+ return dic
+
+
+def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
+ """
+ Flatten a nested dict or list to a list.
+
+ For example, given a dict
+ {
+ a: 1
+ b: {
+ c: 2
+ }
+ c: 3
+ }
+
+ the function would return [(a, 1), (b_c, 2), (c, 3)]
+
+ Args:
+ d (dict, list): a nested dict or list to be flattened
+ parent_key (str): recursion helper
+ sep (str): separator for nesting keys
+ item_key (str): recursion helper
+ Returns:
+ list: a list of (key, value) tuples
+ """
+ items = []
+ if isinstance(d, (tuple, list)):
+ new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
+ for i, v in enumerate(d):
+ items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
+ return items
+ elif isinstance(d, dict):
+ new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
+ for k, v in d.items():
+ assert isinstance(k, str)
+ items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
+ return items
+ else:
+ new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
+ return [(new_key, d)]
+
+
+def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
+ """
+ Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
+ batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
+ Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
+ outputs to [B, T, ...].
+
+ Args:
+ inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ op: a layer op that accepts inputs
+ activation: activation to apply at the output
+ inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
+ inputs_as_args (bool) whether to feed input as a args list to the op
+ kwargs (dict): other kwargs to supply to the op
+
+ Returns:
+ outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
+ """
+ batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
+ inputs = join_dimensions(inputs, 0, 1)
+ if inputs_as_kwargs:
+ outputs = op(**inputs, **kwargs)
+ elif inputs_as_args:
+ outputs = op(*inputs, **kwargs)
+ else:
+ outputs = op(inputs, **kwargs)
+
+ if activation is not None:
+ outputs = map_tensor(outputs, activation)
+ outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
+ return outputs
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/test_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d70c0f073b7b22c0b3608ba5c4408a23904259a
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/test_utils.py
@@ -0,0 +1,270 @@
+"""
+Utilities for testing algorithm implementations - used mainly by scripts in tests directory.
+"""
+import os
+import json
+import shutil
+import traceback
+from termcolor import colored
+
+import numpy as np
+import torch
+
+import robomimic
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.config import Config, config_factory
+from robomimic.scripts.train import train
+
+
+def maybe_remove_dir(dir_to_remove):
+ """
+ Remove directory if it exists.
+
+ Args:
+ dir_to_remove (str): path to directory to remove
+ """
+ if os.path.exists(dir_to_remove):
+ shutil.rmtree(dir_to_remove)
+
+
+def maybe_remove_file(file_to_remove):
+ """
+ Remove file if it exists.
+
+ Args:
+ file_to_remove (str): path to file to remove
+ """
+ if os.path.exists(file_to_remove):
+ os.remove(file_to_remove)
+
+
+def example_dataset_path():
+ """
+ Path to dataset to use for testing and example purposes. It should
+ exist under the tests/assets directory, and will be downloaded
+ from a server if it does not exist.
+ """
+ dataset_folder = os.path.join(robomimic.__path__[0], "../tests/assets/")
+ dataset_path = os.path.join(dataset_folder, "test_v141.hdf5")
+ if not os.path.exists(dataset_path):
+ print("\nWARNING: test hdf5 does not exist! Downloading from server...")
+ os.makedirs(dataset_folder, exist_ok=True)
+ FileUtils.download_url(
+ url="http://downloads.cs.stanford.edu/downloads/rt_benchmark/test_v141.hdf5",
+ download_dir=dataset_folder,
+ )
+ return dataset_path
+
+
+def example_momart_dataset_path():
+ """
+ Path to momart dataset to use for testing and example purposes. It should
+ exist under the tests/assets directory, and will be downloaded
+ from a server if it does not exist.
+ """
+ dataset_folder = os.path.join(robomimic.__path__[0], "../tests/assets/")
+ dataset_path = os.path.join(dataset_folder, "test_momart.hdf5")
+ if not os.path.exists(dataset_path):
+ user_response = input("\nWARNING: momart test hdf5 does not exist! We will download sample dataset. "
+ "This will take 0.6GB space. Proceed? y/n\n")
+ assert user_response.lower() in {"yes", "y"}, f"Did not receive confirmation. Aborting download."
+
+ print("\nDownloading from server...")
+
+ os.makedirs(dataset_folder, exist_ok=True)
+ FileUtils.download_url(
+ url="http://downloads.cs.stanford.edu/downloads/rt_mm/sample/test_momart.hdf5",
+ download_dir=dataset_folder,
+ )
+ return dataset_path
+
+
+def temp_model_dir_path():
+ """
+ Path to a temporary model directory to write to for testing and example purposes.
+ """
+ return os.path.join(robomimic.__path__[0], "../tests/tmp_model_dir")
+
+
+def temp_dataset_path():
+ """
+ Defines default dataset path to write to for testing.
+ """
+ return os.path.join(robomimic.__path__[0], "../tests/", "tmp.hdf5")
+
+
+def temp_video_path():
+ """
+ Defines default video path to write to for testing.
+ """
+ return os.path.join(robomimic.__path__[0], "../tests/", "tmp.mp4")
+
+
+def get_base_config(algo_name):
+ """
+ Base config for testing algorithms.
+
+ Args:
+ algo_name (str): name of algorithm - loads the corresponding json
+ from the config templates directory
+ """
+
+ # we will load and override defaults from template config
+ base_config_path = os.path.join(robomimic.__path__[0], "exps/templates/{}.json".format(algo_name))
+ with open(base_config_path, 'r') as f:
+ config = Config(json.load(f))
+
+ # small dataset with a handful of trajectories
+ config.train.data = example_dataset_path()
+
+ # temporary model dir
+ model_dir = temp_model_dir_path()
+ maybe_remove_dir(model_dir)
+ config.train.output_dir = model_dir
+
+ # train and validate for 3 gradient steps
+ config.experiment.name = "test"
+ config.experiment.validate = True
+ config.experiment.epoch_every_n_steps = 3
+ config.experiment.validation_epoch_every_n_steps = 3
+ config.train.num_epochs = 1
+
+ # default train and validation filter keys
+ config.train.hdf5_filter_key = "train"
+ config.train.hdf5_validation_filter_key = "valid"
+
+ # ensure model saving, rollout, and offscreen video rendering are tested too
+ config.experiment.save.enabled = True
+ config.experiment.save.every_n_epochs = 1
+ config.experiment.rollout.enabled = True
+ config.experiment.rollout.rate = 1
+ config.experiment.rollout.n = 1
+ config.experiment.rollout.horizon = 10
+ config.experiment.render_video = True
+
+ # turn off logging to stdout, since that can interfere with testing code outputs
+ config.experiment.logging.terminal_output_to_txt = False
+
+ # test cuda (if available)
+ config.train.cuda = True
+
+ return config
+
+
+def config_from_modifier(base_config, config_modifier):
+ """
+ Helper function to load a base config, modify it using
+ the passed @config modifier function, and finalize it
+ for training.
+
+ Args:
+ base_config (BaseConfig instance): starting config object that is
+ loaded (to change algorithm config defaults), and then modified
+ with @config_modifier
+
+ config_modifier (function): function that takes a config object as
+ input, and modifies it
+ """
+
+ # algo name to default config for this algorithm
+ algo_name = base_config["algo_name"]
+ config = config_factory(algo_name)
+
+ # update config with the settings specified in the base config
+ with config.unlocked():
+ config.update(base_config)
+
+ # modify the config and finalize it for training (no more modifications allowed)
+ config = config_modifier(config)
+
+ return config
+
+
+def checkpoint_path_from_test_run():
+ """
+ Helper function that gets the path of a model checkpoint after a test training run is finished.
+ """
+ exp_dir = os.path.join(temp_model_dir_path(), "test")
+ time_dir_names = [f.name for f in os.scandir(exp_dir) if f.is_dir()]
+ assert len(time_dir_names) == 1
+ path_to_models = os.path.join(exp_dir, time_dir_names[0], "models")
+ epoch_name = [f.name for f in os.scandir(path_to_models) if f.name.startswith("model")][0]
+ return os.path.join(path_to_models, epoch_name)
+
+
+def test_eval_agent_from_checkpoint(ckpt_path, device):
+ """
+ Test loading a model from checkpoint and running a rollout with the
+ trained agent for a small number of steps.
+
+ Args:
+ ckpt_path (str): path to a checkpoint pth file
+
+ device (torch.Device): torch device
+ """
+
+ # get policy and env from checkpoint
+ policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=ckpt_path, device=device, verbose=True)
+ env, _ = FileUtils.env_from_checkpoint(ckpt_dict=ckpt_dict, verbose=True)
+
+ # run a test rollout
+ ob_dict = env.reset()
+ policy.start_episode()
+ for _ in range(15):
+ ac = policy(ob=ob_dict)
+ ob_dict, r, done, _ = env.step(ac)
+
+
+def test_run(base_config, config_modifier):
+ """
+ Takes a base_config and config_modifier (function that modifies a passed Config object)
+ and runs training as a test. It also takes the trained checkpoint, tries to load the
+ policy and environment from the checkpoint, and run an evaluation rollout. Returns
+ a string that is colored green if the run finished successfully without any issues,
+ and colored red if an error occurred. If an error occurs, the traceback is included
+ in the string.
+
+ Args:
+ base_config (BaseConfig instance): starting config object that is
+ loaded (to change algorithm config defaults), and then modified
+ with @config_modifier
+
+ config_modifier (function): function that takes a config object as
+ input, and modifies it
+
+ Returns:
+ ret (str): a green "passed!" string, or a red "failed with error" string that contains
+ the traceback
+ """
+
+ # disable some macros for testing
+ Macros.RESULTS_SYNC_PATH = None
+ Macros.USE_MAGLEV = False
+ Macros.USE_NGC = False
+
+ try:
+ # get config
+ config = config_from_modifier(base_config=base_config, config_modifier=config_modifier)
+
+ # set torch device
+ device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
+
+ # run training
+ train(config, device=device)
+
+ # test evaluating a trained agent using saved checkpoint
+ ckpt_path = checkpoint_path_from_test_run()
+ test_eval_agent_from_checkpoint(ckpt_path, device=device)
+
+ # indicate success
+ ret = colored("passed!", "green")
+
+ except Exception as e:
+ # indicate failure by returning error string
+ ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")
+
+ # make sure model directory is cleaned up before returning from this function
+ maybe_remove_dir(temp_model_dir_path())
+
+ return ret
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/torch_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..494dbddb7c37a05d8b7c1c66cf96aff8255655fa
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/torch_utils.py
@@ -0,0 +1,489 @@
+"""
+This file contains some PyTorch utilities.
+"""
+import numpy as np
+import torch
+import torch.optim as optim
+
+
+def soft_update(source, target, tau):
+ """
+ Soft update from the parameters of a @source torch module to a @target torch module
+ with strength @tau. The update follows target = target * (1 - tau) + source * tau.
+
+ Args:
+ source (torch.nn.Module): source network to push target network parameters towards
+ target (torch.nn.Module): target network to update
+ """
+ for target_param, param in zip(target.parameters(), source.parameters()):
+ target_param.copy_(
+ target_param * (1.0 - tau) + param * tau
+ )
+
+
+def hard_update(source, target):
+ """
+ Hard update @target parameters to match @source.
+
+ Args:
+ source (torch.nn.Module): source network to provide parameters
+ target (torch.nn.Module): target network to update parameters for
+ """
+ for target_param, param in zip(target.parameters(), source.parameters()):
+ target_param.copy_(param)
+
+
+def get_torch_device(try_to_use_cuda):
+ """
+ Return torch device. If using cuda (GPU), will also set cudnn.benchmark to True
+ to optimize CNNs.
+
+ Args:
+ try_to_use_cuda (bool): if True and cuda is available, will use GPU
+
+ Returns:
+ device (torch.Device): device to use for models
+ """
+ if try_to_use_cuda and torch.cuda.is_available():
+ torch.backends.cudnn.benchmark = True
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+ return device
+
+
+def reparameterize(mu, logvar):
+ """
+ Reparameterize for the backpropagation of z instead of q.
+ This makes it so that we can backpropagate through the sampling of z from
+ our encoder when feeding the sampled variable to the decoder.
+
+ (See "The reparameterization trick" section of https://arxiv.org/abs/1312.6114)
+
+ Args:
+ mu (torch.Tensor): batch of means from the encoder distribution
+ logvar (torch.Tensor): batch of log variances from the encoder distribution
+
+ Returns:
+ z (torch.Tensor): batch of sampled latents from the encoder distribution that
+ support backpropagation
+ """
+ # logvar = \log(\sigma^2) = 2 * \log(\sigma)
+ # \sigma = \exp(0.5 * logvar)
+
+ # clamped for numerical stability
+ logstd = (0.5 * logvar).clamp(-4, 15)
+ std = torch.exp(logstd)
+
+ # Sample \epsilon from normal distribution
+ # use std to create a new tensor, so we don't have to care
+ # about running on GPU or not
+ eps = std.new(std.size()).normal_()
+
+ # Then multiply with the standard deviation and add the mean
+ z = eps.mul(std).add_(mu)
+
+ return z
+
+
+def optimizer_from_optim_params(net_optim_params, net):
+ """
+ Helper function to return a torch Optimizer from the optim_params
+ section of the config for a particular network.
+
+ Args:
+ optim_params (Config): optim_params part of algo_config corresponding
+ to @net. This determines the optimizer that is created.
+
+ net (torch.nn.Module): module whose parameters this optimizer will be
+ responsible
+
+ Returns:
+ optimizer (torch.optim.Optimizer): optimizer
+ """
+ optimizer_type = net_optim_params.get("optimizer_type", "adam")
+ lr = net_optim_params["learning_rate"]["initial"]
+
+ if optimizer_type == "adam":
+ return optim.Adam(
+ params=net.parameters(),
+ lr=lr,
+ weight_decay=net_optim_params["regularization"]["L2"],
+ )
+ elif optimizer_type == "adamw":
+ return optim.AdamW(
+ params=net.parameters(),
+ lr=lr,
+ weight_decay=net_optim_params["regularization"]["L2"],
+ )
+
+
+def lr_scheduler_from_optim_params(net_optim_params, net, optimizer):
+ """
+ Helper function to return a LRScheduler from the optim_params
+ section of the config for a particular network. Returns None
+ if a scheduler is not needed.
+
+ Args:
+ optim_params (Config): optim_params part of algo_config corresponding
+ to @net. This determines whether a learning rate scheduler is created.
+
+ net (torch.nn.Module): module whose parameters this optimizer will be
+ responsible
+
+ optimizer (torch.optim.Optimizer): optimizer for this net
+
+ Returns:
+ lr_scheduler (torch.optim.lr_scheduler or None): learning rate scheduler
+ """
+ lr_scheduler_type = net_optim_params["learning_rate"].get("scheduler_type", "multistep")
+ epoch_schedule = net_optim_params["learning_rate"]["epoch_schedule"]
+
+ lr_scheduler = None
+ if len(epoch_schedule) > 0:
+ if lr_scheduler_type == "linear":
+ assert len(epoch_schedule) == 1
+ end_epoch = epoch_schedule[0]
+
+ return optim.lr_scheduler.LinearLR(
+ optimizer,
+ start_factor=1.0,
+ end_factor=net_optim_params["learning_rate"]["decay_factor"],
+ total_iters=end_epoch,
+ )
+ elif lr_scheduler_type == "multistep":
+ return optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer,
+ milestones=epoch_schedule,
+ gamma=net_optim_params["learning_rate"]["decay_factor"],
+ )
+ else:
+ raise ValueError("Invalid LR scheduler type: {}".format(lr_scheduler_type))
+
+ return lr_scheduler
+
+
+def backprop_for_loss(net, optim, loss, max_grad_norm=None, retain_graph=False):
+ """
+ Backpropagate loss and update parameters for network with
+ name @name.
+
+ Args:
+ net (torch.nn.Module): network to update
+
+ optim (torch.optim.Optimizer): optimizer to use
+
+ loss (torch.Tensor): loss to use for backpropagation
+
+ max_grad_norm (float): if provided, used to clip gradients
+
+ retain_graph (bool): if True, graph is not freed after backward call
+
+ Returns:
+ grad_norms (float): average gradient norms from backpropagation
+ """
+
+ # backprop
+ optim.zero_grad()
+ loss.backward(retain_graph=retain_graph)
+
+ # gradient clipping
+ if max_grad_norm is not None:
+ torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
+
+ # compute grad norms
+ grad_norms = 0.
+ for p in net.parameters():
+ # only clip gradients for parameters for which requires_grad is True
+ if p.grad is not None:
+ grad_norms += p.grad.data.norm(2).pow(2).item()
+
+ # step
+ optim.step()
+
+ return grad_norms
+
+
+def rot_6d_to_axis_angle(rot_6d):
+ """
+ Converts tensor with rot_6d representation to axis-angle representation.
+ """
+ rot_mat = rotation_6d_to_matrix(rot_6d)
+ rot = matrix_to_axis_angle(rot_mat)
+ return rot
+
+
+def axis_angle_to_rot_6d(axis_angle):
+ """
+ Converts tensor with rot_6d representation to axis-angle representation.
+ """
+ rot_mat = axis_angle_to_matrix(axis_angle)
+ rot_6d = matrix_to_rotation_6d(rot_mat)
+ return rot_6d
+
+
+class dummy_context_mgr():
+ """
+ A dummy context manager - useful for having conditional scopes (such
+ as @maybe_no_grad). Nothing happens in this scope.
+ """
+ def __enter__(self):
+ return None
+ def __exit__(self, exc_type, exc_value, traceback):
+ return False
+
+
+def maybe_no_grad(no_grad):
+ """
+ Args:
+ no_grad (bool): if True, the returned context will be torch.no_grad(), otherwise
+ it will be a dummy context
+ """
+ return torch.no_grad() if no_grad else dummy_context_mgr()
+
+
+"""
+The following utility functions were taken from PyTorch3D:
+https://github.com/facebookresearch/pytorch3d/blob/d84f274a0822da969668d00e831870fd88327845/pytorch3d/transforms/rotation_conversions.py
+"""
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+
+ return quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+
+
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to axis/angle.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
+
+
+def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+ half_angles = torch.atan2(norms, quaternions[..., :1])
+ angles = 2 * half_angles
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+ by dropping the last row. Note that 6D representation is not unique.
+ Args:
+ matrix: batch of rotation matrices of size (*, 3, 3)
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ batch_dim = matrix.size()[:-2]
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/train_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c86ef393194c362cdc6ef2adf511c92536c9aae
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/train_utils.py
@@ -0,0 +1,806 @@
+"""
+This file contains several utility functions used to define the main training loop. It
+mainly consists of functions to assist with logging, rollouts, and the @run_epoch function,
+which is the core training logic for models in this repository.
+"""
+import os
+import time
+import datetime
+import shutil
+import json
+import h5py
+import imageio
+import numpy as np
+from copy import deepcopy
+from collections import OrderedDict
+
+import torch
+
+import robomimic
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.log_utils as LogUtils
+import robomimic.utils.file_utils as FileUtils
+import robomimic.utils.obs_utils as ObsUtils
+import robomimic.utils.env_utils as EnvUtils
+import robomimic.macros as Macros
+
+
+from robomimic.utils.dataset import SequenceDataset, R2D2Dataset, MetaDataset
+from robomimic.envs.env_base import EnvBase
+from robomimic.envs.wrappers import EnvWrapper
+from robomimic.algo import RolloutPolicy
+
+
+def get_exp_dir(config, auto_remove_exp_dir=False):
+ """
+ Create experiment directory from config. If an identical experiment directory
+ exists and @auto_remove_exp_dir is False (default), the function will prompt
+ the user on whether to remove and replace it, or keep the existing one and
+ add a new subdirectory with the new timestamp for the current run.
+
+ Args:
+ auto_remove_exp_dir (bool): if True, automatically remove the existing experiment
+ folder if it exists at the same path.
+
+ Returns:
+ log_dir (str): path to created log directory (sub-folder in experiment directory)
+ output_dir (str): path to created models directory (sub-folder in experiment directory)
+ to store model checkpoints
+ video_dir (str): path to video directory (sub-folder in experiment directory)
+ to store rollout videos
+ """
+ assert not (Macros.USE_MAGLEV and Macros.USE_NGC)
+ if Macros.USE_MAGLEV or Macros.USE_NGC:
+ # remove existing experiment directory automatically if path exists so that we don't block on user input
+ auto_remove_exp_dir = True
+
+ # timestamp for directory names
+ t_now = time.time()
+ time_str = datetime.datetime.fromtimestamp(t_now).strftime('%Y%m%d%H%M%S')
+
+ # create directory for where to dump model parameters, tensorboard logs, and videos
+ base_output_dir = os.path.expandvars(os.path.expanduser(config.train.output_dir))
+ if not os.path.isabs(base_output_dir):
+ # relative paths are specified relative to robomimic module location
+ base_output_dir = os.path.join(robomimic.__path__[0], base_output_dir)
+ base_output_dir = os.path.join(base_output_dir, config.experiment.name)
+ if os.path.exists(base_output_dir):
+ if not auto_remove_exp_dir:
+ ans = input("WARNING: model directory ({}) already exists! \noverwrite? (y/n)\n".format(base_output_dir))
+ else:
+ ans = "y"
+ if ans == "y":
+ print("REMOVING")
+ shutil.rmtree(base_output_dir)
+
+ # only make model directory if model saving is enabled
+ output_dir = None
+ if config.experiment.save.enabled:
+ output_dir = os.path.join(base_output_dir, time_str, "models")
+ os.makedirs(output_dir)
+
+ # tensorboard directory
+ log_dir = os.path.join(base_output_dir, time_str, "logs")
+ os.makedirs(log_dir)
+
+ # video directory
+ video_dir = os.path.join(base_output_dir, time_str, "videos")
+ os.makedirs(video_dir)
+
+ # establish sync path for syncing important training results back
+ set_absolute_sync_path(
+ output_dir=config.train.output_dir,
+ exp_name=config.experiment.name,
+ time_str=time_str,
+ )
+
+ return log_dir, output_dir, video_dir
+
+
+def set_absolute_sync_path(output_dir, exp_name, time_str=None):
+ """
+ Establish sync path for syncing important training results back and puts the path
+ into Macros.RESULTS_SYNC_PATH_ABS
+ """
+ need_sync_results = (Macros.USE_MAGLEV and (Macros.MAGLEV_SCRATCH_SYNC_PATH is not None)) or \
+ (Macros.USE_NGC and (Macros.NGC_SCRATCH_SYNC_PATH is not None)) or \
+ ((not Macros.USE_MAGLEV) and (not Macros.USE_NGC) and (Macros.RESULTS_SYNC_PATH is not None))
+ if need_sync_results:
+ # get path where we will sync results
+ assert Macros.RESULTS_SYNC_PATH_ABS is None
+ base_output_dir_name = os.path.basename(os.path.normpath(os.path.expandvars(os.path.expanduser(output_dir))))
+
+ if Macros.USE_MAGLEV:
+ # turn relative scratch space path into absolute scratch space path
+ sync_prefix = os.path.join(
+ os.getenv("WORKFLOW_SCRATCH"),
+ "test_disk", # NOTE: most workflows mount scratch space under this prefix
+ Macros.MAGLEV_SCRATCH_SYNC_PATH,
+ )
+ elif Macros.USE_NGC:
+ sync_prefix = os.path.expandvars(os.path.expanduser(Macros.NGC_SCRATCH_SYNC_PATH))
+ else:
+ sync_prefix = os.path.expandvars(os.path.expanduser(Macros.RESULTS_SYNC_PATH))
+
+ # store at results_sync_path/output_dir_name/experiment_name/time_str
+ sync_path_without_time_dir = os.path.join(
+ sync_prefix,
+ base_output_dir_name,
+ exp_name,
+ )
+ if os.path.exists(sync_path_without_time_dir):
+ # only keep one time directory per exp name
+ shutil.rmtree(sync_path_without_time_dir)
+ Macros.RESULTS_SYNC_PATH_ABS = sync_path_without_time_dir
+ if time_str is not None:
+ Macros.RESULTS_SYNC_PATH_ABS = os.path.join(sync_path_without_time_dir, time_str)
+ os.makedirs(Macros.RESULTS_SYNC_PATH_ABS)
+ elif (Macros.USE_MAGLEV or Macros.USE_NGC):
+ LogUtils.log_warning(
+ "Using MagLev / NGC, but MAGLEV_SCRATCH_SYNC_PATH / NGC_SCRATCH_SYNC_PATH is unset in macros.py."
+ "No results will be synced back to scratch space."
+ )
+
+
+def load_data_for_training(config, obs_keys):
+ """
+ Data loading at the start of an algorithm.
+
+ Args:
+ config (BaseConfig instance): config object
+ obs_keys (list): list of observation modalities that are required for
+ training (this will inform the dataloader on what modalities to load)
+
+ Returns:
+ train_dataset (SequenceDataset instance): train dataset object
+ valid_dataset (SequenceDataset instance): valid dataset object (only if using validation)
+ """
+
+ # config can contain an attribute to filter on
+ train_filter_by_attribute = config.train.hdf5_filter_key
+ valid_filter_by_attribute = config.train.hdf5_validation_filter_key
+ if valid_filter_by_attribute is not None:
+ assert config.experiment.validate, "specified validation filter key {}, but config.experiment.validate is not set".format(valid_filter_by_attribute)
+
+ # load the dataset into memory
+ if config.experiment.validate:
+ assert not config.train.hdf5_normalize_obs, "no support for observation normalization with validation data yet"
+ assert (train_filter_by_attribute is not None) and (valid_filter_by_attribute is not None), \
+ "did not specify filter keys corresponding to train and valid split in dataset" \
+ " - please fill config.train.hdf5_filter_key and config.train.hdf5_validation_filter_key"
+ dataset_path = config.train.data if isinstance(config.train.data, str) else config.train.data[0]["path"]
+ train_demo_keys = FileUtils.get_demos_for_filter_key(
+ hdf5_path=os.path.expanduser(dataset_path),
+ filter_key=train_filter_by_attribute,
+ )
+ valid_demo_keys = FileUtils.get_demos_for_filter_key(
+ hdf5_path=os.path.expanduser(dataset_path),
+ filter_key=valid_filter_by_attribute,
+ )
+ assert set(train_demo_keys).isdisjoint(set(valid_demo_keys)), "training demonstrations overlap with " \
+ "validation demonstrations!"
+ train_dataset = dataset_factory(config, obs_keys, filter_by_attribute=train_filter_by_attribute)
+ valid_dataset = dataset_factory(config, obs_keys, filter_by_attribute=valid_filter_by_attribute)
+ else:
+ train_dataset = dataset_factory(config, obs_keys, filter_by_attribute=train_filter_by_attribute)
+ valid_dataset = None
+
+ return train_dataset, valid_dataset
+
+
+def dataset_factory(config, obs_keys, filter_by_attribute=None, dataset_path=None):
+ """
+ Create a SequenceDataset instance to pass to a torch DataLoader.
+
+ Args:
+ config (BaseConfig instance): config object
+
+ obs_keys (list): list of observation modalities that are required for
+ training (this will inform the dataloader on what modalities to load)
+
+ filter_by_attribute (str): if provided, use the provided filter key
+ to select a subset of demonstration trajectories to load
+
+ dataset_path (str): if provided, the SequenceDataset instance should load
+ data from this dataset path. Defaults to config.train.data.
+
+ Returns:
+ dataset (SequenceDataset instance): dataset object
+ """
+ if dataset_path is None:
+ dataset_path = config.train.data
+
+ ds_kwargs = dict(
+ # hdf5_path=dataset_path,
+ obs_keys=obs_keys,
+ action_keys=config.train.action_keys,
+ dataset_keys=config.train.dataset_keys,
+ action_config=config.train.action_config,
+ load_next_obs=config.train.hdf5_load_next_obs, # whether to load next observations (s') from dataset
+ frame_stack=config.train.frame_stack,
+ seq_length=config.train.seq_length,
+ pad_frame_stack=config.train.pad_frame_stack,
+ pad_seq_length=config.train.pad_seq_length,
+ get_pad_mask=False,
+ goal_mode=config.train.goal_mode,
+ hdf5_cache_mode=config.train.hdf5_cache_mode,
+ hdf5_use_swmr=config.train.hdf5_use_swmr,
+ hdf5_normalize_obs=config.train.hdf5_normalize_obs,
+ # filter_by_attribute=filter_by_attribute
+ )
+
+ if isinstance(dataset_path, str):
+ ds_kwargs["hdf5_path"] = [dataset_path]
+ ds_kwargs["filter_by_attribute"] = [filter_by_attribute]
+ ds_weights = [1.0]
+ ds_labels = ["dummy"]
+ else:
+ ds_kwargs["hdf5_path"] = [ds_cfg["path"] for ds_cfg in config.train.data]
+ ds_kwargs["filter_by_attribute"] = [filter_by_attribute for ds_cfg in config.train.data]
+ ds_weights = [ds_cfg.get("weight", 1.0) for ds_cfg in config.train.data]
+ ds_labels = [ds_cfg.get("label", "dummy") for ds_cfg in config.train.data]
+
+ meta_ds_kwargs = dict()
+
+ dataset = get_dataset(
+ ds_class=R2D2Dataset if config.train.data_format == "r2d2" else SequenceDataset,
+ ds_kwargs=ds_kwargs,
+ ds_weights=ds_weights,
+ ds_labels=ds_labels,
+ normalize_weights_by_ds_size=False,
+ meta_ds_class=MetaDataset,
+ meta_ds_kwargs=meta_ds_kwargs,
+ )
+
+ return dataset
+
+
+def get_dataset(
+ ds_class,
+ ds_kwargs,
+ ds_weights,
+ ds_labels,
+ normalize_weights_by_ds_size,
+ meta_ds_class=MetaDataset,
+ meta_ds_kwargs=None,
+):
+ ds_list = []
+ for i in range(len(ds_weights)):
+
+ ds_kwargs_copy = deepcopy(ds_kwargs)
+
+ keys = ["hdf5_path", "filter_by_attribute"]
+
+ for k in keys:
+ ds_kwargs_copy[k] = ds_kwargs[k][i]
+
+ ds_list.append(ds_class(**ds_kwargs_copy))
+
+ if len(ds_weights) == 1:
+ ds = ds_list[0]
+ else:
+ if meta_ds_kwargs is None:
+ meta_ds_kwargs = dict()
+ ds = meta_ds_class(
+ datasets=ds_list,
+ ds_weights=ds_weights,
+ ds_labels=ds_labels,
+ normalize_weights_by_ds_size=normalize_weights_by_ds_size,
+ **meta_ds_kwargs
+ )
+
+ return ds
+
+
+def run_rollout(
+ policy,
+ env,
+ horizon,
+ use_goals=False,
+ render=False,
+ video_writer=None,
+ video_skip=5,
+ terminate_on_success=False,
+ ):
+ """
+ Runs a rollout in an environment with the current network parameters.
+
+ Args:
+ policy (RolloutPolicy instance): policy to use for rollouts.
+
+ env (EnvBase instance): environment to use for rollouts.
+
+ horizon (int): maximum number of steps to roll the agent out for
+
+ use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env
+
+ render (bool): if True, render the rollout to the screen
+
+ video_writer (imageio Writer instance): if not None, use video writer object to append frames at
+ rate given by @video_skip
+
+ video_skip (int): how often to write video frame
+
+ terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered
+
+ Returns:
+ results (dict): dictionary containing return, success rate, etc.
+ """
+ assert isinstance(policy, RolloutPolicy)
+ assert isinstance(env, EnvBase) or isinstance(env, EnvWrapper)
+
+ policy.start_episode()
+
+ ob_dict = env.reset()
+ goal_dict = None
+ if use_goals:
+ # retrieve goal from the environment
+ goal_dict = env.get_goal()
+
+ results = {}
+ video_count = 0 # video frame counter
+
+ total_reward = 0.
+ success = { k: False for k in env.is_success() } # success metrics
+ got_exception = False
+
+ try:
+ for step_i in range(horizon):
+
+ # get action from policy
+ ac = policy(ob=ob_dict, goal=goal_dict)
+
+ # play action
+ ob_dict, r, done, _ = env.step(ac)
+
+ # render to screen
+ if render:
+ env.render(mode="human")
+
+ # compute reward
+ total_reward += r
+
+ cur_success_metrics = env.is_success()
+ for k in success:
+ success[k] = success[k] or cur_success_metrics[k]
+
+ # visualization
+ if video_writer is not None:
+ if video_count % video_skip == 0:
+ video_img = env.render(mode="rgb_array", height=512, width=512)
+ video_writer.append_data(video_img)
+
+ video_count += 1
+
+ # break if done
+ if done or (terminate_on_success and success["task"]):
+ break
+
+ except env.rollout_exceptions as e:
+ print("WARNING: got rollout exception {}".format(e))
+ got_exception = True
+
+ results["Return"] = total_reward
+ results["Horizon"] = step_i + 1
+ results["Success_Rate"] = float(success["task"])
+ results["Exception_Rate"] = float(got_exception)
+
+ # log additional success metrics
+ for k in success:
+ if k != "task":
+ results["{}_Success_Rate".format(k)] = float(success[k])
+
+ return results
+
+
+def rollout_with_stats(
+ policy,
+ envs,
+ horizon,
+ use_goals=False,
+ num_episodes=None,
+ render=False,
+ video_dir=None,
+ video_path=None,
+ epoch=None,
+ video_skip=5,
+ terminate_on_success=False,
+ verbose=False,
+ ):
+ """
+ A helper function used in the train loop to conduct evaluation rollouts per environment
+ and summarize the results.
+
+ Can specify @video_dir (to dump a video per environment) or @video_path (to dump a single video
+ for all environments).
+
+ Args:
+ policy (RolloutPolicy instance): policy to use for rollouts.
+
+ envs (dict): dictionary that maps env_name (str) to EnvBase instance. The policy will
+ be rolled out in each env.
+
+ horizon (int): maximum number of steps to roll the agent out for
+
+ use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env
+
+ num_episodes (int): number of rollout episodes per environment
+
+ render (bool): if True, render the rollout to the screen
+
+ video_dir (str): if not None, dump rollout videos to this directory (one per environment)
+
+ video_path (str): if not None, dump a single rollout video for all environments
+
+ epoch (int): epoch number (used for video naming)
+
+ video_skip (int): how often to write video frame
+
+ terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered
+
+ verbose (bool): if True, print results of each rollout
+
+ Returns:
+ all_rollout_logs (dict): dictionary of rollout statistics (e.g. return, success rate, ...)
+ averaged across all rollouts
+
+ video_paths (dict): path to rollout videos for each environment
+ """
+ assert isinstance(policy, RolloutPolicy)
+
+ all_rollout_logs = OrderedDict()
+
+ # handle paths and create writers for video writing
+ assert (video_path is None) or (video_dir is None), "rollout_with_stats: can't specify both video path and dir"
+ write_video = (video_path is not None) or (video_dir is not None)
+ video_paths = OrderedDict()
+ video_writers = OrderedDict()
+ if video_path is not None:
+ # a single video is written for all envs
+ video_paths = { k : video_path for k in envs }
+ video_writer = imageio.get_writer(video_path, fps=20)
+ video_writers = { k : video_writer for k in envs }
+ if video_dir is not None:
+ # video is written per env
+ video_str = "_epoch_{}.mp4".format(epoch) if epoch is not None else ".mp4"
+ video_paths = { k : os.path.join(video_dir, "{}{}".format(k, video_str)) for k in envs }
+ video_writers = { k : imageio.get_writer(video_paths[k], fps=20) for k in envs }
+
+ for env_name, env in envs.items():
+ env_video_writer = None
+ if write_video:
+ print("video writes to " + video_paths[env_name])
+ env_video_writer = video_writers[env_name]
+
+ print("rollout: env={}, horizon={}, use_goals={}, num_episodes={}".format(
+ env.name, horizon, use_goals, num_episodes,
+ ))
+ rollout_logs = []
+ iterator = range(num_episodes)
+ if not verbose:
+ iterator = LogUtils.custom_tqdm(iterator, total=num_episodes)
+
+ num_success = 0
+ for ep_i in iterator:
+ rollout_timestamp = time.time()
+ rollout_info = run_rollout(
+ policy=policy,
+ env=env,
+ horizon=horizon,
+ render=render,
+ use_goals=use_goals,
+ video_writer=env_video_writer,
+ video_skip=video_skip,
+ terminate_on_success=terminate_on_success,
+ )
+ rollout_info["time"] = time.time() - rollout_timestamp
+ rollout_logs.append(rollout_info)
+ num_success += rollout_info["Success_Rate"]
+ if verbose:
+ print("Episode {}, horizon={}, num_success={}".format(ep_i + 1, horizon, num_success))
+ print(json.dumps(rollout_info, sort_keys=True, indent=4))
+
+ if video_dir is not None:
+ # close this env's video writer (next env has it's own)
+ env_video_writer.close()
+
+ # average metric across all episodes
+ rollout_logs = dict((k, [rollout_logs[i][k] for i in range(len(rollout_logs))]) for k in rollout_logs[0])
+ rollout_logs_mean = dict((k, np.mean(v)) for k, v in rollout_logs.items())
+ rollout_logs_mean["Time_Episode"] = np.sum(rollout_logs["time"]) / 60. # total time taken for rollouts in minutes
+ all_rollout_logs[env_name] = rollout_logs_mean
+
+ if video_path is not None:
+ # close video writer that was used for all envs
+ video_writer.close()
+
+ return all_rollout_logs, video_paths
+
+
+def should_save_from_rollout_logs(
+ all_rollout_logs,
+ best_return,
+ best_success_rate,
+ epoch_ckpt_name,
+ save_on_best_rollout_return,
+ save_on_best_rollout_success_rate,
+ ):
+ """
+ Helper function used during training to determine whether checkpoints and videos
+ should be saved. It will modify input attributes appropriately (such as updating
+ the best returns and success rates seen and modifying the epoch ckpt name), and
+ returns a dict with the updated statistics.
+
+ Args:
+ all_rollout_logs (dict): dictionary of rollout results that should be consistent
+ with the output of @rollout_with_stats
+
+ best_return (dict): dictionary that stores the best average rollout return seen so far
+ during training, for each environment
+
+ best_success_rate (dict): dictionary that stores the best average success rate seen so far
+ during training, for each environment
+
+ epoch_ckpt_name (str): what to name the checkpoint file - this name might be modified
+ by this function
+
+ save_on_best_rollout_return (bool): if True, should save checkpoints that achieve a
+ new best rollout return
+
+ save_on_best_rollout_success_rate (bool): if True, should save checkpoints that achieve a
+ new best rollout success rate
+
+ Returns:
+ save_info (dict): dictionary that contains updated input attributes @best_return,
+ @best_success_rate, @epoch_ckpt_name, along with two additional attributes
+ @should_save_ckpt (True if should save this checkpoint), and @ckpt_reason
+ (string that contains the reason for saving the checkpoint)
+ """
+ should_save_ckpt = False
+ ckpt_reason = None
+ for env_name in all_rollout_logs:
+ rollout_logs = all_rollout_logs[env_name]
+
+ if rollout_logs["Return"] > best_return[env_name]:
+ best_return[env_name] = rollout_logs["Return"]
+ if save_on_best_rollout_return:
+ # save checkpoint if achieve new best return
+ epoch_ckpt_name += "_{}_return_{}".format(env_name, best_return[env_name])
+ should_save_ckpt = True
+ ckpt_reason = "return"
+
+ if rollout_logs["Success_Rate"] > best_success_rate[env_name]:
+ best_success_rate[env_name] = rollout_logs["Success_Rate"]
+ if save_on_best_rollout_success_rate:
+ # save checkpoint if achieve new best success rate
+ epoch_ckpt_name += "_{}_success_{}".format(env_name, best_success_rate[env_name])
+ should_save_ckpt = True
+ ckpt_reason = "success"
+
+ # return the modified input attributes
+ return dict(
+ best_return=best_return,
+ best_success_rate=best_success_rate,
+ epoch_ckpt_name=epoch_ckpt_name,
+ should_save_ckpt=should_save_ckpt,
+ ckpt_reason=ckpt_reason,
+ )
+
+
+def save_model(model, config, env_meta, shape_meta, ckpt_path, obs_normalization_stats=None, action_normalization_stats=None):
+ """
+ Save model to a torch pth file.
+
+ Args:
+ model (Algo instance): model to save
+
+ config (BaseConfig instance): config to save
+
+ env_meta (dict): env metadata for this training run
+
+ shape_meta (dict): shape metdata for this training run
+
+ ckpt_path (str): writes model checkpoint to this path
+
+ obs_normalization_stats (dict): optionally pass a dictionary for observation
+ normalization. This should map observation keys to dicts
+ with a "mean" and "std" of shape (1, ...) where ... is the default
+ shape for the observation.
+
+ action_normalization_stats (dict): TODO
+ """
+ env_meta = deepcopy(env_meta)
+ shape_meta = deepcopy(shape_meta)
+ params = dict(
+ model=model.serialize(),
+ config=config.dump(),
+ algo_name=config.algo_name,
+ env_metadata=env_meta,
+ shape_metadata=shape_meta,
+ )
+ if obs_normalization_stats is not None:
+ assert config.train.hdf5_normalize_obs
+ obs_normalization_stats = deepcopy(obs_normalization_stats)
+ params["obs_normalization_stats"] = TensorUtils.to_list(obs_normalization_stats)
+ if action_normalization_stats is not None:
+ action_normalization_stats = deepcopy(action_normalization_stats)
+ params["action_normalization_stats"] = TensorUtils.to_list(action_normalization_stats)
+ torch.save(params, ckpt_path)
+ print("save checkpoint to {}".format(ckpt_path))
+
+
+def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_normalization_stats=None):
+ """
+ Run an epoch of training or validation.
+
+ Args:
+ model (Algo instance): model to train
+
+ data_loader (DataLoader instance): data loader that will be used to serve batches of data
+ to the model
+
+ epoch (int): epoch number
+
+ validate (bool): whether this is a training epoch or validation epoch. This tells the model
+ whether to do gradient steps or purely do forward passes.
+
+ num_steps (int): if provided, this epoch lasts for a fixed number of batches (gradient steps),
+ otherwise the epoch is a complete pass through the training dataset
+
+ obs_normalization_stats (dict or None): if provided, this should map observation keys to dicts
+ with a "mean" and "std" of shape (1, ...) where ... is the default
+ shape for the observation.
+
+ Returns:
+ step_log_all (dict): dictionary of logged training metrics averaged across all batches
+ """
+ epoch_timestamp = time.time()
+ if validate:
+ model.set_eval()
+ else:
+ model.set_train()
+ if num_steps is None:
+ num_steps = len(data_loader)
+
+ step_log_all = []
+ timing_stats = dict(Data_Loading=[], Process_Batch=[], Train_Batch=[], Log_Info=[])
+ start_time = time.time()
+
+ data_loader_iter = iter(data_loader)
+ for _ in LogUtils.custom_tqdm(range(num_steps)):
+
+ # load next batch from data loader
+ try:
+ t = time.time()
+ batch = next(data_loader_iter)
+ except StopIteration:
+ # reset for next dataset pass
+ data_loader_iter = iter(data_loader)
+ t = time.time()
+ batch = next(data_loader_iter)
+ timing_stats["Data_Loading"].append(time.time() - t)
+
+ # process batch for training
+ t = time.time()
+ input_batch = model.process_batch_for_training(batch)
+ input_batch = model.postprocess_batch_for_training(input_batch, obs_normalization_stats=obs_normalization_stats)
+ timing_stats["Process_Batch"].append(time.time() - t)
+
+ # forward and backward pass
+ t = time.time()
+ info = model.train_on_batch(input_batch, epoch, validate=validate)
+ timing_stats["Train_Batch"].append(time.time() - t)
+
+ # tensorboard logging
+ t = time.time()
+ step_log = model.log_info(info)
+ step_log_all.append(step_log)
+ timing_stats["Log_Info"].append(time.time() - t)
+
+ # flatten and take the mean of the metrics
+ step_log_dict = {}
+ for i in range(len(step_log_all)):
+ for k in step_log_all[i]:
+ if k not in step_log_dict:
+ step_log_dict[k] = []
+ step_log_dict[k].append(step_log_all[i][k])
+ step_log_all = dict((k, float(np.mean(v))) for k, v in step_log_dict.items())
+
+ # add in timing stats
+ for k in timing_stats:
+ # sum across all training steps, and convert from seconds to minutes
+ step_log_all["Time_{}".format(k)] = np.sum(timing_stats[k]) / 60.
+ step_log_all["Time_Epoch"] = (time.time() - epoch_timestamp) / 60.
+
+ return step_log_all
+
+
+def is_every_n_steps(interval, current_step, skip_zero=False):
+ """
+ Convenient function to check whether current_step is at the interval.
+ Returns True if current_step % interval == 0 and asserts a few corner cases (e.g., interval <= 0)
+
+ Args:
+ interval (int): target interval
+ current_step (int): current step
+ skip_zero (bool): whether to skip 0 (return False at 0)
+
+ Returns:
+ is_at_interval (bool): whether current_step is at the interval
+ """
+ if interval is None:
+ return False
+ assert isinstance(interval, int) and interval > 0
+ assert isinstance(current_step, int) and current_step >= 0
+ if skip_zero and current_step == 0:
+ return False
+ return current_step % interval == 0
+
+
+def get_model_from_output_folder(models_path, videos_path=None, epoch=None, best=False, last=False):
+ """
+ Gets path to model (and video) for a certain epoch number (or the best or last epoch).
+
+ Args:
+ models_path (str): path to models folder (in output directory)
+ videos_path (str): path to videos folder (in output directory)
+ epoch (int): if provided, get model ckpt and video for this epoch
+ best (bool): if True, get the model and video for the best checkpoint (according to success rate)
+ last (bool): if True, get the model and video for the last checkpoint (according to epoch number)
+
+ Returns:
+ model_path (str): path to model pth
+ video_path (str): path to mp4
+ epoch (int): epoch number for retrieved model and video paths
+ """
+
+ # make sure we either grab a specific epoch, best epoch, or last epoch
+ assert sum([(epoch is not None), best, last]) == 1
+
+ # run through models to find the epoch we want
+ best_success_rate = -0.1
+ need_particular_epoch = (epoch is not None)
+ need_best_epoch = best
+ need_max_epoch = last
+
+ selected_epoch = -1
+ selected_model_path = None
+ for f in os.scandir(models_path):
+ model_epoch = int(f.name.split("_")[2].strip(".pth"))
+
+ if need_particular_epoch and (model_epoch == epoch):
+ selected_epoch = epoch
+ selected_model_path = os.path.join(models_path, f.name)
+
+ elif need_best_epoch:
+ # this block assumes that the experiment run opted to save the model with the best checkpoint
+ if "success" in f.name:
+ # example name: model_epoch_250_NutAssemblySquareTarget_6_success_0.86.pth
+ # take last piece - "0.86.pth" -> "0.86" -> convert to float
+ success_rate = float(f.name.split("success_")[-1][:-4])
+ if success_rate > best_success_rate:
+ best_success_rate = success_rate
+ selected_epoch = model_epoch
+ selected_model_path = os.path.join(models_path, f.name)
+
+ elif need_max_epoch:
+ # find last epoch
+ if model_epoch > selected_epoch:
+ selected_epoch = model_epoch
+ selected_model_path = os.path.join(models_path, f.name)
+
+ assert selected_epoch != -1
+ assert selected_model_path is not None
+
+ selected_video_path = None
+ if videos_path is not None:
+ # get random video filename
+ video_fname = None
+ for f in os.scandir(videos_path):
+ video_fname = f.name
+ break
+ # example video file name: NutAssemblySquareTarget_6_epoch_150.mp4
+ # take name skeleton and use it to infer name of source videos we want, then copy them
+ video_name_prefix = video_fname.split("epoch")[0]
+ selected_video_path = os.path.join(videos_path, "{}epoch_{}.mp4".format(video_name_prefix, selected_epoch))
+ return selected_model_path, selected_video_path, selected_epoch
diff --git a/phantom/submodules/phantom-robomimic/robomimic/utils/vis_utils.py b/phantom/submodules/phantom-robomimic/robomimic/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c73d7a1ee504db2a73e62c28e6aaf11a0f714b
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/robomimic/utils/vis_utils.py
@@ -0,0 +1,111 @@
+"""
+This file contains utility functions for visualizing image observations in the training pipeline.
+These functions can be a useful debugging tool.
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.cm as cm
+
+import robomimic.utils.tensor_utils as TensorUtils
+import robomimic.utils.obs_utils as ObsUtils
+
+
+def image_tensor_to_numpy(image):
+ """
+ Converts processed image tensors to numpy so that they can be saved to disk or video.
+ A useful utility function for visualizing images in the middle of training.
+
+ Args:
+ image (torch.Tensor): images of shape [..., C, H, W]
+
+ Returns:
+ image (np.array): converted images of shape [..., H, W, C] and type uint8
+ """
+ return TensorUtils.to_numpy(
+ ObsUtils.unprocess_image(image)
+ ).astype(np.uint8)
+
+
+def image_to_disk(image, fname):
+ """
+ Writes an image to disk.
+
+ Args:
+ image (np.array): image of shape [H, W, 3]
+ fname (str): path to save image to
+ """
+ image = Image.fromarray(image)
+ image.save(fname)
+
+
+def image_tensor_to_disk(image, fname):
+ """
+ Writes an image tensor to disk. Any leading batch dimensions are indexed out
+ with the first element.
+
+ Args:
+ image (torch.Tensor): image of shape [..., C, H, W]. All leading dimensions
+ will be indexed out with the first element
+ fname (str): path to save image to
+ """
+ # index out all leading dimensions before [C, H, W]
+ num_leading_dims = len(image.shape[:-3])
+ for _ in range(num_leading_dims):
+ image = image[0]
+ image = image_tensor_to_numpy(image)
+ image_to_disk(image, fname)
+
+
+def visualize_image_randomizer(original_image, randomized_image, randomizer_name=None):
+ """
+ A function that visualizes the before and after of an image-based input randomizer
+ Args:
+ original_image: batch of original image shaped [B, H, W, 3]
+ randomized_image: randomized image shaped [B, N, H, W, 3]. N is the number of randomization per input sample
+ randomizer_name: (Optional) name of the randomizer
+ Returns:
+ None
+ """
+
+ B, N, H, W, C = randomized_image.shape
+
+ # Create a grid of subplots with B rows and N+1 columns (1 for the original image, N for the randomized images)
+ fig, axes = plt.subplots(B, N + 1, figsize=(4 * (N + 1), 4 * B))
+
+ for i in range(B):
+ # Display the original image in the first column of each row
+ axes[i, 0].imshow(original_image[i])
+ axes[i, 0].set_title("Original")
+ axes[i, 0].axis("off")
+
+ # Display the randomized images in the remaining columns of each row
+ for j in range(N):
+ axes[i, j + 1].imshow(randomized_image[i, j])
+ axes[i, j + 1].axis("off")
+
+ title = randomizer_name if randomizer_name is not None else "Randomized"
+ fig.suptitle(title, fontsize=16)
+
+ # Adjust the space between subplots for better visualization
+ plt.subplots_adjust(wspace=0.5, hspace=0.5)
+
+ # Show the entire grid of subplots
+ plt.show()
+
+
+def depth_to_rgb(depth_map, depth_min=None, depth_max=None):
+ """
+ Convert depth map to rgb array by computing normalized depth values in [0, 1].
+ """
+ # normalize depth map into [0, 1]
+ if depth_min is None:
+ depth_min = depth_map.min()
+ if depth_max is None:
+ depth_max = depth_map.max()
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+ # depth_map = np.clip(depth_map / 3., 0., 1.)
+ if len(depth_map.shape) == 3:
+ assert depth_map.shape[-1] == 1
+ depth_map = depth_map[..., 0]
+ assert len(depth_map.shape) == 2 # [H, W]
+ return (255. * cm.hot(depth_map, 3)).astype(np.uint8)[..., :3]
diff --git a/phantom/submodules/phantom-robomimic/setup.py b/phantom/submodules/phantom-robomimic/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..91eca6ca948f03ee52383e697ce14b00b39856e9
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/setup.py
@@ -0,0 +1,44 @@
+from setuptools import setup, find_packages
+
+# read the contents of your README file
+from os import path
+this_directory = path.abspath(path.dirname(__file__))
+with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
+ lines = f.readlines()
+
+# remove images from README
+lines = [x for x in lines if (('.png' not in x) and ('.gif' not in x))]
+long_description = ''.join(lines)
+
+setup(
+ name="robomimic",
+ packages=[
+ package for package in find_packages() if package.startswith("robomimic")
+ ],
+ install_requires=[
+ "numpy>=1.13.3",
+ "h5py",
+ "psutil",
+ "tqdm",
+ "termcolor",
+ "tensorboard",
+ "tensorboardX",
+ "imageio",
+ "imageio-ffmpeg",
+ "matplotlib",
+ "egl_probe>=1.0.1",
+ # "torch",
+ # "torchvision",
+ "diffusers>=0.26.2",
+ ],
+ eager_resources=['*'],
+ include_package_data=True,
+ python_requires='>=3',
+ description="robomimic: A Modular Framework for Robot Learning from Demonstration",
+ author="Ajay Mandlekar, Danfei Xu, Josiah Wong, Soroush Nasiriany, Chen Wang, Matthew Bronars",
+ url="https://github.com/ARISE-Initiative/robomimic",
+ author_email="amandlek@cs.stanford.edu",
+ version="0.3.0",
+ long_description=long_description,
+ long_description_content_type='text/markdown'
+)
diff --git a/phantom/submodules/phantom-robomimic/tests/test.sh b/phantom/submodules/phantom-robomimic/tests/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e25e77c134753f486bfa3231fc3abe7bef11754c
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+echo "running tests for bc..."
+python test_bc.py
+echo "running tests for hbc..."
+python test_hbc.py
+echo "running tests for iris..."
+python test_iris.py
+echo "running tests for bcq..."
+python test_bcq.py
+echo "running tests for cql..."
+python test_cql.py
+echo "running tests for scripts..."
+python test_scripts.py
+echo "running tests for examples..."
+python test_examples.py
diff --git a/phantom/submodules/phantom-robomimic/tests/test_bc.py b/phantom/submodules/phantom-robomimic/tests/test_bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc125014bafb920aa45d69fc7d48798d3231c6f
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_bc.py
@@ -0,0 +1,295 @@
+"""
+Test script for BC algorithms. Each test trains a variant of BC
+for a handful of gradient steps and tries one rollout with
+the model. Excludes stdout output by default (pass --verbose
+to see stdout output).
+"""
+import argparse
+from collections import OrderedDict
+
+import robomimic
+from robomimic.config import Config
+import robomimic.utils.test_utils as TestUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def get_algo_base_config():
+ """
+ Base config for testing BC algorithms.
+ """
+
+ # config with basic settings for quick training run
+ config = TestUtils.get_base_config(algo_name="bc")
+
+ # low-level obs (note that we define it here because @observation structure might vary per algorithm,
+ # for example HBC)
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.modalities.obs.rgb = []
+
+ # by default, vanilla BC
+ config.algo.gaussian.enabled = False
+ config.algo.gmm.enabled = False
+ config.algo.vae.enabled = False
+ config.algo.rnn.enabled = False
+
+ return config
+
+
+def convert_config_for_images(config):
+ """
+ Modify config to use image observations.
+ """
+
+ # using high-dimensional images - don't load entire dataset into memory, and smaller batch size
+ config.train.hdf5_cache_mode = "low_dim"
+ config.train.num_data_workers = 0
+ config.train.batch_size = 16
+
+ # replace object with rgb modality
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
+ config.observation.modalities.obs.rgb = ["agentview_image"]
+
+ # set up visual encoders
+ config.observation.encoder.rgb.core_class = "VisualCore"
+ config.observation.encoder.rgb.core_kwargs.feature_dimension = 64
+ config.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
+ config.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
+
+ # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
+ config.observation.encoder.rgb.obs_randomizer_class = None
+
+ return config
+
+
+def make_image_modifier(config_modifier):
+ """
+ Turn a config modifier into its image version. Note that
+ this explicit function definition is needed for proper
+ scoping of @config_modifier.
+ """
+ return lambda x: config_modifier(convert_config_for_images(x))
+
+
+# mapping from test name to config modifier functions
+MODIFIERS = OrderedDict()
+def register_mod(test_name):
+ def decorator(config_modifier):
+ MODIFIERS[test_name] = config_modifier
+ return decorator
+
+
+@register_mod("bc")
+def bc_modifier(config):
+ # no-op
+ return config
+
+
+@register_mod("bc-gaussian")
+def bc_gaussian_modifier(config):
+ config.algo.gaussian.enabled = True
+ return config
+
+
+@register_mod("bc-gmm")
+def bc_gmm_modifier(config):
+ config.algo.gmm.enabled = True
+ return config
+
+
+@register_mod("bc-vae, N(0, 1) prior")
+def bc_vae_modifier_1(config):
+ # N(0, 1) prior
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = False
+ config.algo.vae.prior.is_conditioned = False
+ return config
+
+
+@register_mod("bc-vae, Gaussian prior (obs-independent)")
+def bc_vae_modifier_2(config):
+ # learn parameters of Gaussian prior (obs-independent)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = False
+ config.algo.vae.prior.use_gmm = False
+ config.algo.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bc-vae, Gaussian prior (obs-dependent)")
+def bc_vae_modifier_3(config):
+ # learn parameters of Gaussian prior (obs-dependent)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = True
+ config.algo.vae.prior.use_gmm = False
+ config.algo.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bc-vae, GMM prior (obs-independent, weights-fixed)")
+def bc_vae_modifier_4(config):
+ # learn parameters of GMM prior (obs-independent, weights-fixed)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = False
+ config.algo.vae.prior.use_gmm = True
+ config.algo.vae.prior.gmm_learn_weights = False
+ config.algo.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bc-vae, GMM prior (obs-independent, weights-learned)")
+def bc_vae_modifier_5(config):
+ # learn parameters of GMM prior (obs-independent, weights-learned)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = False
+ config.algo.vae.prior.use_gmm = True
+ config.algo.vae.prior.gmm_learn_weights = True
+ config.algo.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bc-vae, GMM prior (obs-dependent, weights-fixed)")
+def bc_vae_modifier_6(config):
+ # learn parameters of GMM prior (obs-dependent, weights-fixed)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = True
+ config.algo.vae.prior.use_gmm = True
+ config.algo.vae.prior.gmm_learn_weights = False
+ config.algo.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bc-vae, GMM prior (obs-dependent, weights-learned)")
+def bc_vae_modifier_7(config):
+ # learn parameters of GMM prior (obs-dependent, weights-learned)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = True
+ config.algo.vae.prior.use_gmm = True
+ config.algo.vae.prior.gmm_learn_weights = True
+ config.algo.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bc-vae, uniform categorical prior")
+def bc_vae_modifier_8(config):
+ # uniform categorical prior
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = False
+ config.algo.vae.prior.is_conditioned = False
+ config.algo.vae.prior.use_gmm = False
+ config.algo.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("bc-vae, categorical prior (obs-independent)")
+def bc_vae_modifier_9(config):
+ # learn parameters of categorical prior (obs-independent)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = False
+ config.algo.vae.prior.use_gmm = False
+ config.algo.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("bc-vae, categorical prior (obs-dependent)")
+def bc_vae_modifier_10(config):
+ # learn parameters of categorical prior (obs-dependent)
+ config.algo.vae.enabled = True
+ config.algo.vae.prior.learn = True
+ config.algo.vae.prior.is_conditioned = True
+ config.algo.vae.prior.use_gmm = False
+ config.algo.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("bc-rnn")
+def bc_rnn_modifier(config):
+ config.algo.rnn.enabled = True
+ config.algo.rnn.horizon = 10
+ config.train.seq_length = 10
+ return config
+
+
+@register_mod("bc-rnn-gmm")
+def bc_rnn_gmm_modifier(config):
+ config.algo.gmm.enabled = True
+ config.algo.rnn.enabled = True
+ config.algo.rnn.horizon = 10
+ config.train.seq_length = 10
+ return config
+
+
+@register_mod("bc-transformer")
+def bc_transformer_modifier(config):
+ config.algo.transformer.enabled = True
+ config.train.frame_stack = 10
+ config.train.seq_length = 1
+ return config
+
+
+@register_mod("bc-transformer-gmm")
+def bc_transformer_gmm_modifier(config):
+ config.algo.gmm.enabled = True
+ config.algo.transformer.enabled = True
+ config.train.frame_stack = 10
+ config.train.seq_length = 1
+ return config
+
+
+# add image version of all tests
+image_modifiers = OrderedDict()
+for test_name in MODIFIERS:
+ lst = test_name.split("-")
+ name = "-".join(lst[:1] + ["rgb"] + lst[1:])
+ image_modifiers[name] = make_image_modifier(MODIFIERS[test_name])
+MODIFIERS.update(image_modifiers)
+
+
+# test for image crop randomization
+@register_mod("bc-image-crop")
+def bc_image_crop_modifier(config):
+ config = convert_config_for_images(config)
+
+ # observation randomizer class - using Crop randomizer
+ config.observation.encoder.rgb.obs_randomizer_class = "CropRandomizer"
+
+ # kwargs for observation randomizers (for the CropRandomizer, this is size and number of crops)
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.num_crops = 1
+ config.observation.encoder.rgb.obs_randomizer_kwargs.pos_enc = False
+ return config
+
+
+def test_bc(silence=True):
+ for test_name in MODIFIERS:
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+ base_config = get_algo_base_config()
+ res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name])
+ print("{}: {}".format(test_name, res_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_bc(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robomimic/tests/test_bcq.py b/phantom/submodules/phantom-robomimic/tests/test_bcq.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8bd08356575e66cb779afa2aebf71b560262cd4
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_bcq.py
@@ -0,0 +1,263 @@
+"""
+Test script for BCQ algorithms. Each test trains a variant of BCQ
+for a handful of gradient steps and tries one rollout with
+the model. Excludes stdout output by default (pass --verbose
+to see stdout output).
+"""
+import argparse
+from collections import OrderedDict
+
+import robomimic
+from robomimic.config import Config
+import robomimic.utils.test_utils as TestUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def get_algo_base_config():
+ """
+ Base config for testing BCQ algorithms.
+ """
+
+ # config with basic settings for quick training run
+ config = TestUtils.get_base_config(algo_name="bcq")
+
+ # low-level obs (note that we define it here because @observation structure might vary per algorithm,
+ # for example HBC)
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.modalities.obs.rgb = []
+
+ # by default, vanilla BCQ
+ config.algo.actor.enabled = True # perturbation actor
+ config.algo.critic.distributional.enabled = False # vanilla critic training
+ config.algo.action_sampler.vae.enabled = True # action sampler is VAE
+ config.algo.action_sampler.gmm.enabled = False
+
+ return config
+
+
+def convert_config_for_images(config):
+ """
+ Modify config to use image observations.
+ """
+
+ # using high-dimensional images - don't load entire dataset into memory, and smaller batch size
+ config.train.hdf5_cache_mode = "low_dim"
+ config.train.num_data_workers = 0
+ config.train.batch_size = 16
+
+ # replace object with rgb modality
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
+ config.observation.modalities.obs.rgb = ["agentview_image"]
+
+ # set up visual encoders
+ config.observation.encoder.rgb.core_class = "VisualCore"
+ config.observation.encoder.rgb.core_kwargs.feature_dimension = 64
+ config.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
+ config.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
+
+ # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
+ config.observation.encoder.rgb.obs_randomizer_class = None
+
+ return config
+
+
+def make_image_modifier(config_modifier):
+ """
+ turn a config modifier into its image version. Note that
+ this explicit function definition is needed for proper
+ scoping of @config_modifier
+ """
+ return lambda x: config_modifier(convert_config_for_images(x))
+
+
+# mapping from test name to config modifier functions
+MODIFIERS = OrderedDict()
+def register_mod(test_name):
+ def decorator(config_modifier):
+ MODIFIERS[test_name] = config_modifier
+ return decorator
+
+
+@register_mod("bcq-no-actor")
+def bcq_no_actor_modifier(config):
+ config.algo.actor.enabled = False
+ return config
+
+
+@register_mod("bcq-distributional")
+def bcq_distributional_modifier(config):
+ config.algo.critic.distributional.enabled = True
+ config.algo.critic.value_bounds = [-100., 100.]
+ return config
+
+
+@register_mod("bcq-as-gmm")
+def bcq_gmm_modifier(config):
+ config.algo.action_sampler.gmm.enabled = True
+ config.algo.action_sampler.vae.enabled = False
+ return config
+
+
+@register_mod("bcq-as-vae, N(0, 1) prior")
+def bcq_vae_modifier_1(config):
+ # N(0, 1) prior
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = False
+ config.algo.action_sampler.vae.prior.is_conditioned = False
+ return config
+
+
+@register_mod("bcq-as-vae, Gaussian prior (obs-independent)")
+def bcq_vae_modifier_2(config):
+ # learn parameters of Gaussian prior (obs-independent)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = False
+ config.algo.action_sampler.vae.prior.use_gmm = False
+ config.algo.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bcq-as-vae, Gaussian prior (obs-dependent)")
+def bcq_vae_modifier_3(config):
+ # learn parameters of Gaussian prior (obs-dependent)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = True
+ config.algo.action_sampler.vae.prior.use_gmm = False
+ config.algo.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bcq-as-vae, GMM prior (obs-independent, weights-fixed)")
+def bcq_vae_modifier_4(config):
+ # learn parameters of GMM prior (obs-independent, weights-fixed)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = False
+ config.algo.action_sampler.vae.prior.use_gmm = True
+ config.algo.action_sampler.vae.prior.gmm_learn_weights = False
+ config.algo.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bcq-as-vae, GMM prior (obs-independent, weights-learned)")
+def bcq_vae_modifier_5(config):
+ # learn parameters of GMM prior (obs-independent, weights-learned)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = False
+ config.algo.action_sampler.vae.prior.use_gmm = True
+ config.algo.action_sampler.vae.prior.gmm_learn_weights = True
+ config.algo.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bcq-as-vae, GMM prior (obs-dependent, weights-fixed)")
+def bcq_vae_modifier_6(config):
+ # learn parameters of GMM prior (obs-dependent, weights-fixed)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = True
+ config.algo.action_sampler.vae.prior.use_gmm = True
+ config.algo.action_sampler.vae.prior.gmm_learn_weights = False
+ config.algo.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bcq-as-vae, GMM prior (obs-dependent, weights-learned)")
+def bcq_vae_modifier_7(config):
+ # learn parameters of GMM prior (obs-dependent, weights-learned)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = True
+ config.algo.action_sampler.vae.prior.use_gmm = True
+ config.algo.action_sampler.vae.prior.gmm_learn_weights = True
+ config.algo.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("bcq-as-vae, uniform categorical prior")
+def bcq_vae_modifier_8(config):
+ # uniform categorical prior
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = False
+ config.algo.action_sampler.vae.prior.is_conditioned = False
+ config.algo.action_sampler.vae.prior.use_gmm = False
+ config.algo.action_sampler.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("bcq-as-vae, categorical prior (obs-independent)")
+def bcq_vae_modifier_9(config):
+ # learn parameters of categorical prior (obs-independent)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = False
+ config.algo.action_sampler.vae.prior.use_gmm = False
+ config.algo.action_sampler.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("bcq-as-vae, categorical prior (obs-dependent)")
+def bcq_vae_modifier_10(config):
+ # learn parameters of categorical prior (obs-dependent)
+ config.algo.action_sampler.vae.enabled = True
+ config.algo.action_sampler.vae.prior.learn = True
+ config.algo.action_sampler.vae.prior.is_conditioned = True
+ config.algo.action_sampler.vae.prior.use_gmm = False
+ config.algo.action_sampler.vae.prior.use_categorical = True
+ return config
+
+
+# add image version of all tests
+image_modifiers = OrderedDict()
+for test_name in MODIFIERS:
+ lst = test_name.split("-")
+ name = "-".join(lst[:1] + ["rgb"] + lst[1:])
+ image_modifiers[name] = make_image_modifier(MODIFIERS[test_name])
+MODIFIERS.update(image_modifiers)
+
+
+# test for image crop randomization
+@register_mod("bcq-image-crop")
+def bcq_image_crop_modifier(config):
+ config = convert_config_for_images(config)
+
+ # observation randomizer class - using Crop randomizer
+ config.observation.encoder.rgb.obs_randomizer_class = "CropRandomizer"
+
+ # kwargs for observation randomizers (for the CropRandomizer, this is size and number of crops)
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.num_crops = 1
+ config.observation.encoder.rgb.obs_randomizer_kwargs.pos_enc = False
+ return config
+
+
+def test_bcq(silence=True):
+ for test_name in MODIFIERS:
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+ base_config = get_algo_base_config()
+ res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name])
+ print("{}: {}".format(test_name, res_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_bcq(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robomimic/tests/test_cql.py b/phantom/submodules/phantom-robomimic/tests/test_cql.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78c4bf222e4fb312c0b9395a65001ccbd474f83
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_cql.py
@@ -0,0 +1,152 @@
+"""
+Test script for CQL algorithms. Each test trains a variant of CQL
+for a handful of gradient steps and tries one rollout with
+the model. Excludes stdout output by default (pass --verbose
+to see stdout output).
+"""
+import argparse
+from collections import OrderedDict
+
+import robomimic
+from robomimic.config import Config
+import robomimic.utils.test_utils as TestUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def get_algo_base_config():
+ """
+ Base config for testing CQL algorithms.
+ """
+
+ # config with basic settings for quick training run
+ config = TestUtils.get_base_config(algo_name="cql")
+
+ # low-level obs (note that we define it here because @observation structure might vary per algorithm,
+ # for example HBC)
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.modalities.obs.rgb = []
+
+ # by default, vanilla CQL
+ config.algo.actor.bc_start_steps = 40 # BC training initially
+ config.algo.critic.target_q_gap = 5.0 # use automatic cql tuning
+ config.algo.actor.target_entropy = "default" # use automatic entropy tuning
+
+ # lower batch size to 100 to accomodate small test dataset
+ config.train.batch_size = 100
+
+ return config
+
+
+def convert_config_for_images(config):
+ """
+ Modify config to use image observations.
+ """
+
+ # using high-dimensional images - don't load entire dataset into memory, and smaller batch size
+ config.train.hdf5_cache_mode = "low_dim"
+ config.train.num_data_workers = 0
+ config.train.batch_size = 16
+
+ # replace object with rgb modality
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
+ config.observation.modalities.obs.rgb = ["agentview_image"]
+
+ # set up visual encoders
+ config.observation.encoder.rgb.core_class = "VisualCore"
+ config.observation.encoder.rgb.core_kwargs.feature_dimension = 64
+ config.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
+ config.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
+
+ # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
+ config.observation.encoder.rgb.obs_randomizer_class = None
+
+ return config
+
+
+def make_image_modifier(config_modifier):
+ """
+ turn a config modifier into its image version. Note that
+ this explicit function definition is needed for proper
+ scoping of @config_modifier
+ """
+ return lambda x: config_modifier(convert_config_for_images(x))
+
+
+# mapping from test name to config modifier functions
+MODIFIERS = OrderedDict()
+def register_mod(test_name):
+ def decorator(config_modifier):
+ MODIFIERS[test_name] = config_modifier
+ return decorator
+
+
+@register_mod("cql-fixed-entropy")
+def cql_entropy_modifier(config):
+ config.algo.actor.target_entropy = None
+ return config
+
+
+@register_mod("cql-fixed-q-gap")
+def cql_q_gap_modifier(config):
+ config.algo.critic.target_q_gap = None
+ config.algo.critic.cql_weight = 1.0
+ return config
+
+
+@register_mod("cql-fixed-gaussian")
+def cql_gaussian_modifier(config):
+ config.algo.actor.net.gaussian.fixed_std = True
+ return config
+
+
+# add image version of all tests
+image_modifiers = OrderedDict()
+for test_name in MODIFIERS:
+ lst = test_name.split("-")
+ name = "-".join(lst[:1] + ["rgb"] + lst[1:])
+ image_modifiers[name] = make_image_modifier(MODIFIERS[test_name])
+MODIFIERS.update(image_modifiers)
+
+
+# test for image crop randomization
+@register_mod("cql-image-crop")
+def cql_image_crop_modifier(config):
+ config = convert_config_for_images(config)
+
+ # observation randomizer class - using Crop randomizer
+ config.observation.encoder.rgb.obs_randomizer_class = "CropRandomizer"
+
+ # kwargs for observation randomizers (for the CropRandomizer, this is size and number of crops)
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.num_crops = 1
+ config.observation.encoder.rgb.obs_randomizer_kwargs.pos_enc = False
+ return config
+
+
+def test_cql(silence=True):
+ for test_name in MODIFIERS:
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+ base_config = get_algo_base_config()
+ res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name])
+ print("{}: {}".format(test_name, res_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_cql(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robomimic/tests/test_examples.py b/phantom/submodules/phantom-robomimic/tests/test_examples.py
new file mode 100644
index 0000000000000000000000000000000000000000..6696015f18bd3ef59bf4e04fd19d41a1026dd0e3
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_examples.py
@@ -0,0 +1,84 @@
+"""
+Tests for the provided examples in the repository. Excludes stdout output
+by default (pass --verbose to see stdout output).
+"""
+import argparse
+import traceback
+import os
+import subprocess
+import time
+import h5py
+import numpy as np
+import torch
+from collections import OrderedDict
+from termcolor import colored
+
+import robomimic
+import robomimic.utils.test_utils as TestUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def test_example_script(script_name, args_string, test_name, silence=True):
+ """
+ Helper function to run an example script with filename @script_name and
+ with test name @test_name (which will be printed to terminal with
+ the stderr output of the example script).
+ """
+
+ # run example script
+ stdout = subprocess.DEVNULL if silence else None
+ path_to_script = os.path.join(robomimic.__path__[0], "../examples/{}".format(script_name))
+ example_job = subprocess.Popen("python {} {}".format(path_to_script, args_string),
+ shell=True, stdout=stdout, stderr=subprocess.PIPE)
+ example_job.wait()
+
+ # get stderr output
+ out, err = example_job.communicate()
+ err = err.decode("utf-8")
+ if len(err) > 0:
+ ret = "maybe failed - stderr output below (if it's only from tqdm, the test passed)\n{}".format(err)
+ ret = colored(ret, "red")
+ else:
+ ret = colored("passed", "green")
+ print("{}: {}".format(test_name, ret))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_example_script(
+ script_name="simple_config.py",
+ args_string="",
+ test_name="simple-config-example",
+ silence=(not args.verbose),
+ )
+ test_example_script(
+ script_name="simple_obs_nets.py",
+ args_string="",
+ test_name="simple-obs-nets-example",
+ silence=(not args.verbose),
+ )
+ test_example_script(
+ script_name="simple_train_loop.py",
+ args_string="",
+ test_name="simple-train-loop-example",
+ silence=(not args.verbose),
+ )
+ # clear tmp model dir before running script
+ TestUtils.maybe_remove_dir(TestUtils.temp_model_dir_path())
+ test_example_script(
+ script_name="train_bc_rnn.py",
+ args_string="--debug",
+ test_name="train-bc-rnn-example",
+ silence=(not args.verbose),
+ )
+ # cleanup
+ TestUtils.maybe_remove_dir(TestUtils.temp_model_dir_path())
diff --git a/phantom/submodules/phantom-robomimic/tests/test_hbc.py b/phantom/submodules/phantom-robomimic/tests/test_hbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e55606960b85164a36d3b6d9fc35bde1c6f0fa03
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_hbc.py
@@ -0,0 +1,184 @@
+"""
+Test script for HBC algorithm. Each test trains a variant of HBC
+for a handful of gradient steps and tries one rollout with
+the model. Excludes stdout output by default (pass --verbose
+to see stdout output).
+"""
+import argparse
+from collections import OrderedDict
+
+import robomimic
+import robomimic.utils.test_utils as TestUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def get_algo_base_config():
+ """
+ Base config for testing BCQ algorithms.
+ """
+
+ # config with basic settings for quick training run
+ config = TestUtils.get_base_config(algo_name="hbc")
+
+ # low-level obs (note that we define it here because @observation structure might vary per algorithm,
+ # for example HBC)
+ config.observation.planner.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.planner.modalities.obs.rgb = []
+
+ config.observation.planner.modalities.subgoal.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.planner.modalities.subgoal.rgb = []
+
+ config.observation.actor.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.actor.modalities.obs.rgb = []
+
+ # by default, planner is deterministic prediction
+ config.algo.planner.vae.enabled = False
+
+ return config
+
+
+# mapping from test name to config modifier functions
+MODIFIERS = OrderedDict()
+def register_mod(test_name):
+ def decorator(config_modifier):
+ MODIFIERS[test_name] = config_modifier
+ return decorator
+
+
+@register_mod("hbc")
+def hbc_modifier(config):
+ # no-op
+ return config
+
+
+@register_mod("hbc-vae, N(0, 1) prior")
+def hbc_vae_modifier_1(config):
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = False
+ config.algo.planner.vae.prior.is_conditioned = False
+ return config
+
+
+@register_mod("hbc-vae, Gaussian prior (obs-independent)")
+def hbc_vae_modifier_2(config):
+ # learn parameters of Gaussian prior (obs-independent)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = False
+ config.algo.planner.vae.prior.use_gmm = False
+ config.algo.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("hbc-vae, Gaussian prior (obs-dependent)")
+def hbc_vae_modifier_3(config):
+ # learn parameters of Gaussian prior (obs-dependent)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = True
+ config.algo.planner.vae.prior.use_gmm = False
+ config.algo.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("hbc-vae, GMM prior (obs-independent, weights-fixed)")
+def hbc_vae_modifier_4(config):
+ # learn parameters of GMM prior (obs-independent, weights-fixed)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = False
+ config.algo.planner.vae.prior.use_gmm = True
+ config.algo.planner.vae.prior.gmm_learn_weights = False
+ config.algo.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("hbc-vae, GMM prior (obs-independent, weights-learned)")
+def hbc_vae_modifier_5(config):
+ # learn parameters of GMM prior (obs-independent, weights-learned)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = False
+ config.algo.planner.vae.prior.use_gmm = True
+ config.algo.planner.vae.prior.gmm_learn_weights = True
+ config.algo.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("hbc-vae, GMM prior (obs-dependent, weights-fixed)")
+def hbc_vae_modifier_6(config):
+ # learn parameters of GMM prior (obs-dependent, weights-fixed)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = True
+ config.algo.planner.vae.prior.use_gmm = True
+ config.algo.planner.vae.prior.gmm_learn_weights = False
+ config.algo.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("hbc-vae, GMM prior (obs-dependent, weights-learned)")
+def hbc_vae_modifier_7(config):
+ # learn parameters of GMM prior (obs-dependent, weights-learned)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = True
+ config.algo.planner.vae.prior.use_gmm = True
+ config.algo.planner.vae.prior.gmm_learn_weights = True
+ config.algo.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("hbc-vae, uniform categorical prior")
+def hbc_vae_modifier_8(config):
+ # uniform categorical prior
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = False
+ config.algo.planner.vae.prior.is_conditioned = False
+ config.algo.planner.vae.prior.use_gmm = False
+ config.algo.planner.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("hbc-vae, categorical prior (obs-independent)")
+def hbc_vae_modifier_9(config):
+ # learn parameters of categorical prior (obs-independent)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = False
+ config.algo.planner.vae.prior.use_gmm = False
+ config.algo.planner.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("hbc-vae, categorical prior (obs-dependent)")
+def hbc_vae_modifier_10(config):
+ # learn parameters of categorical prior (obs-dependent)
+ config.algo.planner.vae.enabled = True
+ config.algo.planner.vae.prior.learn = True
+ config.algo.planner.vae.prior.is_conditioned = True
+ config.algo.planner.vae.prior.use_gmm = False
+ config.algo.planner.vae.prior.use_categorical = True
+ return config
+
+
+def test_hbc(silence=True):
+ for test_name in MODIFIERS:
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+ base_config = get_algo_base_config()
+ res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name])
+ print("{}: {}".format(test_name, res_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_hbc(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robomimic/tests/test_iql.py b/phantom/submodules/phantom-robomimic/tests/test_iql.py
new file mode 100644
index 0000000000000000000000000000000000000000..e80a8f3bdc2949166d65c32de1134c85fc6c7901
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_iql.py
@@ -0,0 +1,143 @@
+"""
+Test script for IQL algorithms. Each test trains a variant of IQL
+for a handful of gradient steps and tries one rollout with
+the model. Excludes stdout output by default (pass --verbose
+to see stdout output).
+"""
+import argparse
+from collections import OrderedDict
+
+import robomimic
+from robomimic.config import Config
+import robomimic.utils.test_utils as TestUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def get_algo_base_config():
+ """
+ Base config for testing IQL algorithms.
+ """
+
+ # config with basic settings for quick training run
+ config = TestUtils.get_base_config(algo_name="iql")
+
+ # low-level obs (note that we define it here because @observation structure might vary per algorithm,
+ # for example HBC)
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.modalities.obs.rgb = []
+
+ return config
+
+
+def convert_config_for_images(config):
+ """
+ Modify config to use image observations.
+ """
+
+ # using high-dimensional images - don't load entire dataset into memory, and smaller batch size
+ config.train.hdf5_cache_mode = "low_dim"
+ config.train.num_data_workers = 0
+ config.train.batch_size = 16
+
+ # replace object with rgb modality
+ config.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
+ config.observation.modalities.obs.rgb = ["agentview_image"]
+
+ # set up visual encoders
+ config.observation.encoder.rgb.core_class = "VisualCore"
+ config.observation.encoder.rgb.core_kwargs.feature_dimension = 64
+ config.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
+ config.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
+ config.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
+ config.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
+
+ # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
+ config.observation.encoder.rgb.obs_randomizer_class = None
+
+ return config
+
+
+def make_image_modifier(config_modifier):
+ """
+ turn a config modifier into its image version. Note that
+ this explicit function definition is needed for proper
+ scoping of @config_modifier
+ """
+ return lambda x: config_modifier(convert_config_for_images(x))
+
+
+# mapping from test name to config modifier functions
+MODIFIERS = OrderedDict()
+def register_mod(test_name):
+ def decorator(config_modifier):
+ MODIFIERS[test_name] = config_modifier
+ return decorator
+
+
+@register_mod("iql-gaussian")
+def iql_default_modifier(config):
+ config.algo.actor.net.type = "gaussian"
+ return config
+
+
+@register_mod("iql-gmm")
+def iql_default_modifier(config):
+ config.algo.actor.net.type = "gmm"
+ return config
+
+
+@register_mod("iql-clip-adv")
+def iql_default_modifier(config):
+ config.algo.adv.clip_adv_value = 1.0
+ return config
+
+
+# add image version of all tests
+image_modifiers = OrderedDict()
+for test_name in MODIFIERS:
+ lst = test_name.split("-")
+ name = "-".join(lst[:1] + ["rgb"] + lst[1:])
+ image_modifiers[name] = make_image_modifier(MODIFIERS[test_name])
+MODIFIERS.update(image_modifiers)
+
+
+# test for image crop randomization
+@register_mod("iql-image-crop")
+def iql_image_crop_modifier(config):
+ config = convert_config_for_images(config)
+
+ # observation randomizer class - using Crop randomizer
+ config.observation.encoder.rgb.obs_randomizer_class = "CropRandomizer"
+
+ # kwargs for observation randomizers (for the CropRandomizer, this is size and number of crops)
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 76
+ config.observation.encoder.rgb.obs_randomizer_kwargs.num_crops = 1
+ config.observation.encoder.rgb.obs_randomizer_kwargs.pos_enc = False
+ return config
+
+
+def test_iql(silence=True):
+ for test_name in MODIFIERS:
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+ base_config = get_algo_base_config()
+ res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name])
+ print("{}: {}".format(test_name, res_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_iql(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robomimic/tests/test_iris.py b/phantom/submodules/phantom-robomimic/tests/test_iris.py
new file mode 100644
index 0000000000000000000000000000000000000000..126c5c288fff150bbde55aa8570e3e4e31df3943
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_iris.py
@@ -0,0 +1,302 @@
+"""
+Test script for IRIS algorithms. Each test trains a variant of IRIS
+for a handful of gradient steps and tries one rollout with
+the model. Excludes stdout output by default (pass --verbose
+to see stdout output).
+"""
+import argparse
+from collections import OrderedDict
+
+import robomimic
+import robomimic.utils.test_utils as TestUtils
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+
+
+def get_algo_base_config():
+ """
+ Base config for testing BCQ algorithms.
+ """
+
+ # config with basic settings for quick training run
+ config = TestUtils.get_base_config(algo_name="iris")
+
+ # low-level obs (note that we define it here because @observation structure might vary per algorithm,
+ # for example iris)
+ config.observation.value_planner.planner.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.value_planner.planner.modalities.obs.rgb = []
+
+ config.observation.value_planner.planner.modalities.subgoal.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.value_planner.planner.modalities.subgoal.rgb = []
+
+ config.observation.value_planner.value.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.value_planner.value.modalities.obs.rgb = []
+
+ config.observation.actor.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"]
+ config.observation.actor.modalities.obs.rgb = []
+
+ # by default, basic N(0, 1) prior for both planner VAE and BCQ cVAE
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = False
+ config.algo.value_planner.planner.vae.prior.is_conditioned = False
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = False
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False
+
+ return config
+
+
+# mapping from test name to config modifier functions
+MODIFIERS = OrderedDict()
+def register_mod(test_name):
+ def decorator(config_modifier):
+ MODIFIERS[test_name] = config_modifier
+ return decorator
+
+
+@register_mod("iris")
+def iris_modifier_1(config):
+ # no-op
+ return config
+
+
+@register_mod("iris, planner vae Gaussian prior (obs-independent)")
+def iris_modifier_2(config):
+ # learn parameters of Gaussian prior (obs-independent)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = False
+ config.algo.value_planner.planner.vae.prior.use_gmm = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, planner vae Gaussian prior (obs-dependent)")
+def iris_modifier_3(config):
+ # learn parameters of Gaussian prior (obs-dependent)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = True
+ config.algo.value_planner.planner.vae.prior.use_gmm = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, planner vae GMM prior (obs-independent, weights-fixed)")
+def iris_modifier_4(config):
+ # learn parameters of GMM prior (obs-independent, weights-fixed)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = False
+ config.algo.value_planner.planner.vae.prior.use_gmm = True
+ config.algo.value_planner.planner.vae.prior.gmm_learn_weights = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, planner vae GMM prior (obs-independent, weights-learned)")
+def iris_modifier_5(config):
+ # learn parameters of GMM prior (obs-independent, weights-learned)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = False
+ config.algo.value_planner.planner.vae.prior.use_gmm = True
+ config.algo.value_planner.planner.vae.prior.gmm_learn_weights = True
+ config.algo.value_planner.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, planner vae GMM prior (obs-dependent, weights-fixed)")
+def iris_modifier_6(config):
+ # learn parameters of GMM prior (obs-dependent, weights-fixed)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = True
+ config.algo.value_planner.planner.vae.prior.use_gmm = True
+ config.algo.value_planner.planner.vae.prior.gmm_learn_weights = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, planner vae GMM prior (obs-dependent, weights-learned)")
+def iris_modifier_7(config):
+ # learn parameters of GMM prior (obs-dependent, weights-learned)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = True
+ config.algo.value_planner.planner.vae.prior.use_gmm = True
+ config.algo.value_planner.planner.vae.prior.gmm_learn_weights = True
+ config.algo.value_planner.planner.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, planner vae uniform categorical prior")
+def iris_modifier_8(config):
+ # uniform categorical prior
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = False
+ config.algo.value_planner.planner.vae.prior.is_conditioned = False
+ config.algo.value_planner.planner.vae.prior.use_gmm = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("iris, planner vae categorical prior (obs-independent)")
+def iris_modifier_9(config):
+ # learn parameters of categorical prior (obs-independent)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = False
+ config.algo.value_planner.planner.vae.prior.use_gmm = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("iris, planner vae categorical prior (obs-dependent)")
+def iris_modifier_10(config):
+ # learn parameters of categorical prior (obs-dependent)
+ config.algo.value_planner.planner.vae.enabled = True
+ config.algo.value_planner.planner.vae.prior.learn = True
+ config.algo.value_planner.planner.vae.prior.is_conditioned = True
+ config.algo.value_planner.planner.vae.prior.use_gmm = False
+ config.algo.value_planner.planner.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("iris, bcq gmm")
+def iris_modifier_11(config):
+ # bcq action sampler is GMM
+ config.algo.value_planner.value.action_sampler.gmm.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.enabled = False
+ return config
+
+
+@register_mod("iris, bcq distributional")
+def iris_modifier_12(config):
+ # bcq value function is distributional
+ config.algo.value_planner.value.critic.distributional.enabled = True
+ config.algo.value_planner.value.critic.value_bounds = [-100., 100.]
+ return config
+
+@register_mod("iris, bcq cVAE Gaussian prior (obs-independent)")
+def iris_modifier_13(config):
+ # learn parameters of Gaussian prior (obs-independent)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, bcq cVAE Gaussian prior (obs-dependent)")
+def iris_modifier_14(config):
+ # learn parameters of Gaussian prior (obs-dependent)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, bcq cVAE GMM prior (obs-independent, weights-fixed)")
+def iris_modifier_15(config):
+ # learn parameters of GMM prior (obs-independent, weights-fixed)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True
+ config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, bcq cVAE GMM prior (obs-independent, weights-learned)")
+def iris_modifier_16(config):
+ # learn parameters of GMM prior (obs-independent, weights-learned)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True
+ config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = True
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, bcq cVAE GMM prior (obs-dependent, weights-fixed)")
+def iris_modifier_17(config):
+ # learn parameters of GMM prior (obs-dependent, weights-fixed)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True
+ config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, bcq cVAE GMM prior (obs-dependent, weights-learned)")
+def iris_modifier_18(config):
+ # learn parameters of GMM prior (obs-dependent, weights-learned)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True
+ config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = True
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False
+ return config
+
+
+@register_mod("iris, bcq cVAE uniform categorical prior")
+def iris_modifier_19(config):
+ # uniform categorical prior
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = False
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("iris, bcq cVAE categorical prior (obs-independent)")
+def iris_modifier_20(config):
+ # learn parameters of categorical prior (obs-independent)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = True
+ return config
+
+
+@register_mod("iris, bcq cVAE categorical prior (obs-dependent)")
+def iris_modifier_21(config):
+ # learn parameters of categorical prior (obs-dependent)
+ config.algo.value_planner.value.action_sampler.vae.enabled = True
+ config.algo.value_planner.value.action_sampler.vae.prior.learn = True
+ config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True
+ config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False
+ config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = True
+ return config
+
+
+def test_iris(silence=True):
+ for test_name in MODIFIERS:
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+ base_config = get_algo_base_config()
+ res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name])
+ print("{}: {}".format(test_name, res_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_iris(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robomimic/tests/test_scripts.py b/phantom/submodules/phantom-robomimic/tests/test_scripts.py
new file mode 100644
index 0000000000000000000000000000000000000000..30ed7f6112028f1403eb64b87ff4448664358869
--- /dev/null
+++ b/phantom/submodules/phantom-robomimic/tests/test_scripts.py
@@ -0,0 +1,170 @@
+"""
+Tests for a handful of scripts. Excludes stdout output by
+default (pass --verbose to see stdout output).
+"""
+import argparse
+import traceback
+import h5py
+import numpy as np
+import torch
+from collections import OrderedDict
+from termcolor import colored
+
+import robomimic
+import robomimic.utils.test_utils as TestUtils
+import robomimic.utils.torch_utils as TorchUtils
+from robomimic.config import Config
+from robomimic.utils.log_utils import silence_stdout
+from robomimic.utils.torch_utils import dummy_context_mgr
+from robomimic.scripts.train import train
+from robomimic.scripts.playback_dataset import playback_dataset
+from robomimic.scripts.run_trained_agent import run_trained_agent
+
+
+def get_checkpoint_to_test():
+ """
+ Run a quick training run to get a checkpoint. This function runs a basic bc-image
+ training run. RGB modality is used for a harder test case for the run agent
+ script, which will need to also try writing image observations to the rollout
+ dataset.
+ """
+
+ # prepare image training run
+ config = TestUtils.get_base_config(algo_name="bc")
+
+ def image_modifier(conf):
+ # using high-dimensional images - don't load entire dataset into memory, and smaller batch size
+ conf.train.hdf5_cache_mode = "low_dim"
+ conf.train.num_data_workers = 0
+ conf.train.batch_size = 16
+
+ # replace object with rgb modality
+ conf.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
+ conf.observation.modalities.obs.rgb = ["agentview_image"]
+
+ # set up visual encoders
+ conf.observation.encoder.rgb.core_class = "VisualCore"
+ conf.observation.encoder.rgb.core_kwargs.feature_dimension = 64
+ conf.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
+ conf.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
+ conf.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
+ conf.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
+ conf.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
+ conf.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
+ conf.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
+ conf.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
+
+ # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
+ conf.observation.encoder.rgb.obs_randomizer_class = None
+
+ return conf
+
+ config = TestUtils.config_from_modifier(base_config=config, config_modifier=image_modifier)
+
+ # run training
+ device = TorchUtils.get_torch_device(try_to_use_cuda=True)
+ train(config, device=device)
+
+ # return checkpoint
+ ckpt_path = TestUtils.checkpoint_path_from_test_run()
+ return ckpt_path
+
+
+def test_playback_script(silence=True, use_actions=False, use_obs=False):
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+
+ try:
+ # setup args and run script
+ args = argparse.Namespace()
+ args.dataset = TestUtils.example_dataset_path()
+ args.filter_key = None
+ args.n = 3 # playback 3 demonstrations
+ args.use_actions = use_actions
+ args.use_obs = use_obs
+ args.render = False
+ args.video_path = TestUtils.temp_video_path() # dump video
+ args.video_skip = 5
+ if use_obs:
+ # camera observation names
+ args.render_image_names = ["agentview_image", "robot0_eye_in_hand_image"]
+ else:
+ # camera names
+ args.render_image_names = ["agentview", "robot0_eye_in_hand"]
+ args.first = False
+ playback_dataset(args)
+
+ # indicate success
+ ret = colored("passed!", "green")
+
+ except Exception as e:
+ # indicate failure by returning error string
+ ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")
+
+ # delete output video
+ TestUtils.maybe_remove_file(TestUtils.temp_video_path())
+
+ act_str = "-action_playback" if use_actions else ""
+ obs_str = "-obs" if use_obs else ""
+ test_name = "playback-script{}{}".format(act_str, obs_str)
+ print("{}: {}".format(test_name, ret))
+
+
+def test_run_agent_script(silence=True):
+ context = silence_stdout() if silence else dummy_context_mgr()
+ with context:
+
+ try:
+ # get a model checkpoint
+ ckpt_path = get_checkpoint_to_test()
+
+ # setup args and run script
+ args = argparse.Namespace()
+ args.agent = ckpt_path
+ args.n_rollouts = 3 # 3 rollouts
+ args.horizon = 10 # short rollouts - 10 steps
+ args.env = None
+ args.render = False
+ args.video_path = TestUtils.temp_video_path() # dump video
+ args.video_skip = 5
+ args.camera_names = ["agentview", "robot0_eye_in_hand"]
+ args.dataset_path = TestUtils.temp_dataset_path() # dump dataset
+ args.dataset_obs = True
+ args.seed = 0
+ run_trained_agent(args)
+
+ # simple sanity check for shape of image observations in rollout dataset
+ f = h5py.File(TestUtils.temp_dataset_path(), "r")
+ assert f["data/demo_1/obs/agentview_image"].shape == (10, 84, 84, 3)
+ assert f["data/demo_1/obs/agentview_image"].dtype == np.uint8
+ f.close()
+
+ # indicate success
+ ret = colored("passed!", "green")
+
+ except Exception as e:
+ # indicate failure by returning error string
+ ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")
+
+ # delete trained model directory, output video, and output dataset
+ TestUtils.maybe_remove_dir(TestUtils.temp_model_dir_path())
+ TestUtils.maybe_remove_file(TestUtils.temp_video_path())
+ TestUtils.maybe_remove_file(TestUtils.temp_dataset_path())
+
+ test_name = "run-agent-script"
+ print("{}: {}".format(test_name, ret))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--verbose",
+ action='store_true',
+ help="don't suppress stdout during tests",
+ )
+ args = parser.parse_args()
+
+ test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=False)
+ test_playback_script(silence=(not args.verbose), use_actions=True, use_obs=False)
+ test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=True)
+ test_run_agent_script(silence=(not args.verbose))
diff --git a/phantom/submodules/phantom-robosuite/.gitignore b/phantom/submodules/phantom-robosuite/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a4fac789a46293cb89d30431ce350f3b617c87d6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/.gitignore
@@ -0,0 +1,117 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# mac
+.DS_Store
+
+# mujoco-key
+mjkey.txt
+
+.mujocomanip_temp_model.xml
+
+*.jpg
+.idea
+
+.pytest_cache/
+
+# private macros
+macros_private.py
diff --git a/phantom/submodules/phantom-robosuite/.pre-commit-config.yaml b/phantom/submodules/phantom-robosuite/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..32ee1fc68dc7abfff165bfe215ef82d1e2a3ea6b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+ - repo: https://github.com/psf/black
+ rev: 22.10.0 # Replace by any tag/version: https://github.com/psf/black/tags
+ hooks:
+ - id: black
+ language_version: python3 # Should be a command that runs python3.6+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+ name: isort (python)
diff --git a/phantom/submodules/phantom-robosuite/AUTHORS b/phantom/submodules/phantom-robosuite/AUTHORS
new file mode 100644
index 0000000000000000000000000000000000000000..281ace4009f198d4e5c46f7dcbd98decc20fc943
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/AUTHORS
@@ -0,0 +1,31 @@
+# This file contains an official list of authors of this framework.
+
+# Names should be added to this file as:
+# Name or Organization
+# The email address is not required for organizations.
+
+Core Team
+Yuke Zhu
+Josiah Wong
+Ajay Mandlekar
+Roberto Martín-Martín
+Abhishek Joshi
+Soroush Nasiriany
+Yifeng Zhu
+
+Past Contributors
+Jiren Zhu
+Jim (Linxi) Fan
+Orien Zeng
+Anchit Gupta
+Zihua Liu
+Joan Creus-Costa
+Anchit Gupta
+Michelle Lee
+Andrew Kondrich
+Rachel Gardner
+Jonathan Booher
+Danfei Xu
+Rachel Gardner
+Albert Tung
+Divyansh Jha
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/CONTRIBUTING.md b/phantom/submodules/phantom-robosuite/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..6c85921f8091257d0a7c564aef74d0f01206831b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/CONTRIBUTING.md
@@ -0,0 +1,47 @@
+How to Contribute
+=================
+
+We are so happy to see you reading this page!
+
+Our team wholeheartedly welcomes the community to contribute to robosuite. Contributions from members of the community will help ensure the long-term success of this project. Before you plan to make contributions, here are important resources to get started with:
+
+- Read the robosuite [documentation](https://robosuite.ai/docs/overview.html) and [whitepaper](https://robosuite.ai/assets/whitepaper.pdf)
+- Check our latest status from existing [issues](https://github.com/ARISE-Initiative/robosuite/issues), [pull requests](https://github.com/ARISE-Initiative/robosuite/pulls), and [branches](https://github.com/ARISE-Initiative/robosuite/branches) and avoid duplicate efforts
+- Join our [ARISE Slack](https://ariseinitiative.slack.com) workspace for technical discussions. Please [email us](mailto:yukez@cs.utexas.edu) to be added to the workspace.
+
+We encourage the community to make four major types of contributions:
+
+- **Bug fixes**: Address open issues and fix bugs presented in the `master` branch
+- **Environment designs:** Design new environments and add them to our existing set of [environments](https://github.com/ARISE-Initiative/robosuite/tree/master/robosuite/environments)
+- **Additional assets:** Incorporate new [models](https://github.com/ARISE-Initiative/robosuite/tree/master/robosuite/models) and functionalities of robots, grippers, objects, and workspaces
+- **New functionalities:** Implement new features, such as dynamics randomization, rendering tools, new controllers, etc.
+
+Testing
+-------
+Before submitting your contributions, make sure that the changes do not break existing functionalities.
+We have a handful of [tests](https://github.com/ARISE-Initiative/robosuite/tree/master/tests) for verifying the correctness of the code.
+You can run all the tests with the following command in the root folder of robosuite. Make sure that it does not throw any error before you proceed to the next step.
+```sh
+$ python -m pytest
+```
+
+Submission
+----------
+Please read the coding conventions below and make sure that your code is consistent with ours. We use the [black](https://github.com/psf/black) and [isort](https://github.com/pycqa/isort) as the [pre-commit](https://pre-commit.com/) hooks to format the source code before code review. To install these hooks, first `pip install pre-commit; pre-commit install` to set them up. Once set up, these hooks should be automatically triggered when committing new changes. If you want to manually check the format of the codes that have already been committed, please run `pre-commit run --all-files` in the project folder.
+
+When making a contribution, make a [pull request](https://docs.github.com/en/free-pro-team@latest/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests)
+to robosuite with an itemized list of what you have done. When you submit a pull request, it is immensely helpful to include example script(s) that showcase the proposed changes and highlight any new APIs.
+We always love to see more test coverage. When it is appropriate, add a new test to the [tests](https://github.com/ARISE-Initiative/robosuite/tree/master/tests) folder for checking the correctness of your code.
+
+Coding Conventions
+------------------
+In addition to the pre-commit hooks, we value readability and adhere to the following coding conventions:
+- Indent using four spaces (soft tabs)
+- Always put spaces after list items and method parameters (e.g., `[1, 2, 3]` rather than `[1,2,3]`), and around operators and hash arrows (e.g., `x += 1` rather than `x+=1`)
+- Use the [Google Python Style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) for the docstrings
+- For scripts such as in [demos](https://github.com/ARISE-Initiative/robosuite/tree/master/robosuite/demos) and [tests](https://github.com/ARISE-Initiative/robosuite/tree/master/tests),
+ include a docstring at the top of the file that describes the high-level purpose of the script and/or instructions on how to use the scripts (if relevant).
+
+We look forward to your contributions. Thanks!
+
+The robosuite core team
diff --git a/phantom/submodules/phantom-robosuite/LICENSE b/phantom/submodules/phantom-robosuite/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e20189628cc282626830a7855edf722de38dc700
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/LICENSE
@@ -0,0 +1,28 @@
+MIT License
+
+Copyright (c) 2022 Stanford Vision and Learning Lab and UT Robot Perception and Learning Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+This software includes the partial implementation of Deepmind Mujoco https://github.com/deepmind/mujoco.
+Deepmind Mujoco is licensed under the Apache License, Version 2.0 (the "License");
+you may not use the files except in compliance with the License.
+
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
diff --git a/phantom/submodules/phantom-robosuite/MANIFEST.in b/phantom/submodules/phantom-robosuite/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..fa3a4c99e41758a1e0515f0d281e07a36374add4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/MANIFEST.in
@@ -0,0 +1,4 @@
+recursive-include robosuite/controllers/config/ *
+recursive-include robosuite/demos *
+recursive-include robosuite/models/assets/ *
+recursive-include robosuite/scripts *
diff --git a/phantom/submodules/phantom-robosuite/README.md b/phantom/submodules/phantom-robosuite/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..38e02567f34ddfbb58e0bf25525cf8cde259b685
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/README.md
@@ -0,0 +1,47 @@
+# robosuite
+
+
+
+[**[Homepage]**](https://robosuite.ai/) [**[White Paper]**](https://arxiv.org/abs/2009.12293) [**[Documentations]**](https://robosuite.ai/docs/overview.html) [**[ARISE Initiative]**](https://github.com/ARISE-Initiative)
+
+-------
+## Latest Updates
+- [11/15/2022] **v1.4**: Backend migration to DeepMind's official [MuJoCo Python binding](https://github.com/deepmind/mujoco), robot textures, and bug fixes :robot: [[release notes]](https://github.com/ARISE-Initiative/robosuite/releases/tag/v1.4.0) [[documentation]](http://robosuite.ai/docs/v1.4/)
+
+- [10/19/2021] **v1.3**: Ray tracing and physically based rendering tools :sparkles: and access to additional vision modalities 🎥 [[video spotlight]](https://www.youtube.com/watch?v=2xesly6JrQ8) [[release notes]](https://github.com/ARISE-Initiative/robosuite/releases/tag/v1.3) [[documentation]](http://robosuite.ai/docs/v1.3/)
+
+- [02/17/2021] **v1.2**: Added observable sensor models :eyes: and dynamics randomization :game_die: [[release notes]](https://github.com/ARISE-Initiative/robosuite/releases/tag/v1.2)
+
+- [12/17/2020] **v1.1**: Refactored infrastructure and standardized model classes for much easier environment prototyping :wrench: [[release notes]](https://github.com/ARISE-Initiative/robosuite/releases/tag/v1.1)
+
+-------
+
+**robosuite** is a simulation framework powered by the [MuJoCo](http://mujoco.org/) physics engine for robot learning. It also offers a suite of benchmark environments for reproducible research. The current release (v1.4) features long-term support with the official MuJoCo binding from DeepMind. This project is part of the broader [Advancing Robot Intelligence through Simulated Environments (ARISE) Initiative](https://github.com/ARISE-Initiative), with the aim of lowering the barriers of entry for cutting-edge research at the intersection of AI and Robotics.
+
+Data-driven algorithms, such as reinforcement learning and imitation learning, provide a powerful and generic tool in robotics. These learning paradigms, fueled by new advances in deep learning, have achieved some exciting successes in a variety of robot control problems. However, the challenges of reproducibility and the limited accessibility of robot hardware (especially during a pandemic) have impaired research progress. The overarching goal of **robosuite** is to provide researchers with:
+
+* a standardized set of benchmarking tasks for rigorous evaluation and algorithm development;
+* a modular design that offers great flexibility to design new robot simulation environments;
+* a high-quality implementation of robot controllers and off-the-shelf learning algorithms to lower the barriers to entry.
+
+This framework was originally developed since late 2017 by researchers in [Stanford Vision and Learning Lab](http://svl.stanford.edu) (SVL) as an internal tool for robot learning research. Now it is actively maintained and used for robotics research projects in SVL and the [UT Robot Perception and Learning Lab](http://rpl.cs.utexas.edu) (RPL). We welcome community contributions to this project. For details please check out our [contributing guidelines](CONTRIBUTING.md).
+
+This release of **robosuite** contains seven robot models, eight gripper models, six controller modes, and nine standardized tasks. It also offers a modular design of APIs for building new environments with procedural generation. We highlight these primary features below:
+
+* **standardized tasks**: a set of standardized manipulation tasks of large diversity and varying complexity and RL benchmarking results for reproducible research;
+* **procedural generation**: modular APIs for programmatically creating new environments and new tasks as combinations of robot models, arenas, and parameterized 3D objects;
+* **robot controllers**: a selection of controller types to command the robots, such as joint-space velocity control, inverse kinematics control, operational space control, and 3D motion devices for teleoperation;
+* **multi-modal sensors**: heterogeneous types of sensory signals, including low-level physical states, RGB cameras, depth maps, and proprioception;
+* **human demonstrations**: utilities for collecting human demonstrations, replaying demonstration datasets, and leveraging demonstration data for learning. Check out our sister project [robomimic](https://arise-initiative.github.io/robomimic-web/);
+* **photorealistic rendering**: integration with advanced graphics tools that provide real-time photorealistic renderings of simulated scenes.
+
+## Citation
+Please cite [**robosuite**](https://robosuite.ai) if you use this framework in your publications:
+```bibtex
+@inproceedings{robosuite2020,
+ title={robosuite: A Modular Simulation Framework and Benchmark for Robot Learning},
+ author={Yuke Zhu and Josiah Wong and Ajay Mandlekar and Roberto Mart\'{i}n-Mart\'{i}n and Abhishek Joshi and Soroush Nasiriany and Yifeng Zhu},
+ booktitle={arXiv preprint arXiv:2009.12293},
+ year={2020}
+}
+```
diff --git a/phantom/submodules/phantom-robosuite/pyproject.toml b/phantom/submodules/phantom-robosuite/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..6871e11a49971eef9e89416f8b5398efc4a1b522
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/pyproject.toml
@@ -0,0 +1,15 @@
+[tool.black]
+line-length = 120
+target-version = ["py36", "py37", "py38"]
+extend-exclude = "robosuite/((models/assets)|(controllers/config))"
+
+[tool.isort]
+profile = "black"
+line_length = 120
+skip = ["__init__.py"]
+filter_files = true
+py_version = "all"
+extend_skip = [
+ "robosuite/models/assets",
+ "robosuite/controllers/config",
+]
diff --git a/phantom/submodules/phantom-robosuite/requirements-extra.txt b/phantom/submodules/phantom-robosuite/requirements-extra.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5f607ba9fdb88f725ee888089104e001c3e61a4f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/requirements-extra.txt
@@ -0,0 +1,14 @@
+# required for IK controllers
+pybullet-svl>=3.1.6.4
+
+# required for GymWrapper
+gymnasium
+
+# macOS only
+hidapi
+
+# required for demonstration utils
+h5py
+
+# required for nvisii renderer
+open3d
diff --git a/phantom/submodules/phantom-robosuite/requirements.txt b/phantom/submodules/phantom-robosuite/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6e1198b1ab1f5a7f19c6f1fc2ba7338438cf718
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/requirements.txt
@@ -0,0 +1 @@
+-e .
diff --git a/phantom/submodules/phantom-robosuite/robosuite/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88a03d99d94563900bae6711aeeca06961df0a3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/__init__.py
@@ -0,0 +1,29 @@
+from robosuite.environments.base import make
+
+# Manipulation environments
+from robosuite.environments.manipulation.lift import Lift
+from robosuite.environments.manipulation.stack import Stack
+from robosuite.environments.manipulation.nut_assembly import NutAssembly
+from robosuite.environments.manipulation.pick_place import PickPlace
+from robosuite.environments.manipulation.door import Door
+from robosuite.environments.manipulation.wipe import Wipe
+from robosuite.environments.manipulation.tool_hang import ToolHang
+from robosuite.environments.manipulation.two_arm_lift import TwoArmLift
+from robosuite.environments.manipulation.two_arm_peg_in_hole import TwoArmPegInHole
+from robosuite.environments.manipulation.two_arm_handover import TwoArmHandover
+from robosuite.environments.manipulation.two_arm_transport import TwoArmTransport
+from robosuite.environments.manipulation.phantom import Phantom
+from robosuite.environments.manipulation.phantom_bimanual import PhantomBimanual
+
+from robosuite.environments import ALL_ENVIRONMENTS
+from robosuite.controllers import ALL_CONTROLLERS, load_controller_config
+from robosuite.robots import ALL_ROBOTS
+from robosuite.models.grippers import ALL_GRIPPERS
+
+__version__ = "1.4.1"
+__logo__ = """
+ ; / ,--.
+ ["] ["] ,< |__**|
+ /[_]\ [~]\/ |// |
+ ] [ OOO /o|__|
+"""
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f2616277d88f49dccb1493e301b4a8d523eca8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/__init__.py
@@ -0,0 +1,17 @@
+from .controller_factory import controller_factory, load_controller_config, reset_controllers, get_pybullet_server
+from .osc import OperationalSpaceController
+from .joint_pos import JointPositionController
+from .joint_vel import JointVelocityController
+from .joint_tor import JointTorqueController
+
+
+CONTROLLER_INFO = {
+ "JOINT_VELOCITY": "Joint Velocity",
+ "JOINT_TORQUE": "Joint Torque",
+ "JOINT_POSITION": "Joint Position",
+ "OSC_POSITION": "Operational Space Control (Position Only)",
+ "OSC_POSE": "Operational Space Control (Position + Orientation)",
+ "IK_POSE": "Inverse Kinematics Control (Position + Orientation) (Note: must have PyBullet installed)",
+}
+
+ALL_CONTROLLERS = CONTROLLER_INFO.keys()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/base_controller.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/base_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..895952a5a4e300ae58ec998cc379853d06fff689
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/base_controller.py
@@ -0,0 +1,273 @@
+import abc
+from collections.abc import Iterable
+
+import mujoco
+import numpy as np
+
+import robosuite.macros as macros
+
+
+class Controller(object, metaclass=abc.ABCMeta):
+ """
+ General controller interface.
+
+ Requires reference to mujoco sim object, eef_name of specific robot, relevant joint_indexes to that robot, and
+ whether an initial_joint is used for nullspace torques or not
+
+ Args:
+ sim (MjSim): Simulator instance this controller will pull robot state updates from
+
+ eef_name (str): Name of controlled robot arm's end effector (from robot XML)
+
+ joint_indexes (dict): Each key contains sim reference indexes to relevant robot joint information, namely:
+
+ :`'joints'`: list of indexes to relevant robot joints
+ :`'qpos'`: list of indexes to relevant robot joint positions
+ :`'qvel'`: list of indexes to relevant robot joint velocities
+
+ actuator_range (2-tuple of array of float): 2-Tuple (low, high) representing the robot joint actuator range
+ """
+
+ def __init__(
+ self,
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ ):
+
+ # Actuator range
+ self.actuator_min = actuator_range[0]
+ self.actuator_max = actuator_range[1]
+
+ # Attributes for scaling / clipping inputs to outputs
+ self.action_scale = None
+ self.action_input_transform = None
+ self.action_output_transform = None
+
+ # Private property attributes
+ self.control_dim = None
+ self.output_min = None
+ self.output_max = None
+ self.input_min = None
+ self.input_max = None
+
+ # mujoco simulator state
+ self.sim = sim
+ self.model_timestep = macros.SIMULATION_TIMESTEP
+ self.eef_name = eef_name
+ self.joint_index = joint_indexes["joints"]
+ self.qpos_index = joint_indexes["qpos"]
+ self.qvel_index = joint_indexes["qvel"]
+
+ # robot states
+ self.ee_pos = None
+ self.ee_ori_mat = None
+ self.ee_pos_vel = None
+ self.ee_ori_vel = None
+ self.joint_pos = None
+ self.joint_vel = None
+
+ # dynamics and kinematics
+ self.J_pos = None
+ self.J_ori = None
+ self.J_full = None
+ self.mass_matrix = None
+
+ # Joint dimension
+ self.joint_dim = len(joint_indexes["joints"])
+
+ # Torques being outputted by the controller
+ self.torques = None
+
+ # Update flag to prevent redundant update calls
+ self.new_update = True
+
+ # Move forward one timestep to propagate updates before taking first update
+ self.sim.forward()
+
+ # Initialize controller by updating internal state and setting the initial joint, pos, and ori
+ self.update()
+ self.initial_joint = self.joint_pos
+ self.initial_ee_pos = self.ee_pos
+ self.initial_ee_ori_mat = self.ee_ori_mat
+
+ @abc.abstractmethod
+ def run_controller(self):
+ """
+ Abstract method that should be implemented in all subclass controllers, and should convert a given action
+ into torques (pre gravity compensation) to be executed on the robot.
+ Additionally, resets the self.new_update flag so that the next self.update call will occur
+ """
+ self.new_update = True
+
+ def scale_action(self, action):
+ """
+ Clips @action to be within self.input_min and self.input_max, and then re-scale the values to be within
+ the range self.output_min and self.output_max
+
+ Args:
+ action (Iterable): Actions to scale
+
+ Returns:
+ np.array: Re-scaled action
+ """
+
+ if self.action_scale is None:
+ self.action_scale = abs(self.output_max - self.output_min) / abs(self.input_max - self.input_min)
+ self.action_output_transform = (self.output_max + self.output_min) / 2.0
+ self.action_input_transform = (self.input_max + self.input_min) / 2.0
+ action = np.clip(action, self.input_min, self.input_max)
+ transformed_action = (action - self.action_input_transform) * self.action_scale + self.action_output_transform
+
+ return transformed_action
+
+ def update(self, force=False):
+ """
+ Updates the state of the robot arm, including end effector pose / orientation / velocity, joint pos/vel,
+ jacobian, and mass matrix. By default, since this is a non-negligible computation, multiple redundant calls
+ will be ignored via the self.new_update attribute flag. However, if the @force flag is set, the update will
+ occur regardless of that state of self.new_update. This base class method of @run_controller resets the
+ self.new_update flag
+
+ Args:
+ force (bool): Whether to force an update to occur or not
+ """
+
+ # Only run update if self.new_update or force flag is set
+ if self.new_update or force:
+ self.sim.forward()
+
+ self.ee_pos = np.array(self.sim.data.site_xpos[self.sim.model.site_name2id(self.eef_name)])
+ self.ee_ori_mat = np.array(
+ self.sim.data.site_xmat[self.sim.model.site_name2id(self.eef_name)].reshape([3, 3])
+ )
+ self.ee_pos_vel = np.array(self.sim.data.get_site_xvelp(self.eef_name))
+ self.ee_ori_vel = np.array(self.sim.data.get_site_xvelr(self.eef_name))
+
+ self.joint_pos = np.array(self.sim.data.qpos[self.qpos_index])
+ self.joint_vel = np.array(self.sim.data.qvel[self.qvel_index])
+
+ self.J_pos = np.array(self.sim.data.get_site_jacp(self.eef_name).reshape((3, -1))[:, self.qvel_index])
+ self.J_ori = np.array(self.sim.data.get_site_jacr(self.eef_name).reshape((3, -1))[:, self.qvel_index])
+ self.J_full = np.array(np.vstack([self.J_pos, self.J_ori]))
+
+ mass_matrix = np.ndarray(shape=(self.sim.model.nv, self.sim.model.nv), dtype=np.float64, order="C")
+ mujoco.mj_fullM(self.sim.model._model, mass_matrix, self.sim.data.qM)
+ mass_matrix = np.reshape(mass_matrix, (len(self.sim.data.qvel), len(self.sim.data.qvel)))
+ self.mass_matrix = mass_matrix[self.qvel_index, :][:, self.qvel_index]
+
+ # Clear self.new_update
+ self.new_update = False
+
+ def update_base_pose(self, base_pos, base_ori):
+ """
+ Optional function to implement in subclass controllers that will take in @base_pos and @base_ori and update
+ internal configuration to account for changes in the respective states. Useful for controllers e.g. IK, which
+ is based on pybullet and requires knowledge of simulator state deviations between pybullet and mujoco
+
+ Args:
+ base_pos (3-tuple): x,y,z position of robot base in mujoco world coordinates
+ base_ori (4-tuple): x,y,z,w orientation or robot base in mujoco world coordinates
+ """
+ pass
+
+ def update_initial_joints(self, initial_joints):
+ """
+ Updates the internal attribute self.initial_joints. This is useful for updating changes in controller-specific
+ behavior, such as with OSC where self.initial_joints is used for determine nullspace actions
+
+ This function can also be extended by subclassed controllers for additional controller-specific updates
+
+ Args:
+ initial_joints (Iterable): Array of joint position values to update the initial joints
+ """
+ self.initial_joint = np.array(initial_joints)
+ self.update(force=True)
+ self.initial_ee_pos = self.ee_pos
+ self.initial_ee_ori_mat = self.ee_ori_mat
+
+ def clip_torques(self, torques):
+ """
+ Clips the torques to be within the actuator limits
+
+ Args:
+ torques (Iterable): Torques to clip
+
+ Returns:
+ np.array: Clipped torques
+ """
+ return np.clip(torques, self.actuator_min, self.actuator_max)
+
+ def reset_goal(self):
+ """
+ Resets the goal -- usually by setting to the goal to all zeros, but in some cases may be different (e.g.: OSC)
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def nums2array(nums, dim):
+ """
+ Convert input @nums into numpy array of length @dim. If @nums is a single number, broadcasts it to the
+ corresponding dimension size @dim before converting into a numpy array
+
+ Args:
+ nums (numeric or Iterable): Either single value or array of numbers
+ dim (int): Size of array to broadcast input to env.sim.data.actuator_force
+
+ Returns:
+ np.array: Array filled with values specified in @nums
+ """
+ # First run sanity check to make sure no strings are being inputted
+ if isinstance(nums, str):
+ raise TypeError("Error: Only numeric inputs are supported for this function, nums2array!")
+
+ # Check if input is an Iterable, if so, we simply convert the input to np.array and return
+ # Else, input is a single value, so we map to a numpy array of correct size and return
+ return np.array(nums) if isinstance(nums, Iterable) else np.ones(dim) * nums
+
+ @property
+ def torque_compensation(self):
+ """
+ Gravity compensation for this robot arm
+
+ Returns:
+ np.array: torques
+ """
+ return self.sim.data.qfrc_bias[self.qvel_index]
+
+ @property
+ def actuator_limits(self):
+ """
+ Torque limits for this controller
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum actuator torques
+ - (np.array) maximum actuator torques
+ """
+ return self.actuator_min, self.actuator_max
+
+ @property
+ def control_limits(self):
+ """
+ Limits over this controller's action space, which defaults to input min/max
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum action values
+ - (np.array) maximum action values
+ """
+ return self.input_min, self.input_max
+
+ @property
+ def name(self):
+ """
+ Name of this controller
+
+ Returns:
+ str: controller name
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_baxter.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_baxter.json
new file mode 100644
index 0000000000000000000000000000000000000000..960d52fcd3389ee6cbfff72d6cb99be38bd02533
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_baxter.json
@@ -0,0 +1,11 @@
+{
+ "type": "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1, 1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_iiwa.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_iiwa.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d9018da0c99eb3edeaded02ee54723697b53da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_iiwa.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_jaco.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_jaco.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d9018da0c99eb3edeaded02ee54723697b53da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_jaco.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_kinova3.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_kinova3.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d9018da0c99eb3edeaded02ee54723697b53da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_kinova3.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_panda.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_panda.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d9018da0c99eb3edeaded02ee54723697b53da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_panda.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_sawyer.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_sawyer.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d9018da0c99eb3edeaded02ee54723697b53da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_sawyer.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_ur5e.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_ur5e.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d9018da0c99eb3edeaded02ee54723697b53da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/default_ur5e.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 0.03,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/ik_pose.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/ik_pose.json
new file mode 100644
index 0000000000000000000000000000000000000000..45a0223f202b95020f5f94ecc78118b0c286e3c1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/ik_pose.json
@@ -0,0 +1,7 @@
+{
+ "type" : "IK_POSE",
+ "ik_pos_limit": 0.02,
+ "ik_ori_limit": 0.05,
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_position.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_position.json
new file mode 100644
index 0000000000000000000000000000000000000000..86cb4f576fc13a738c83fe93938482362a9f4284
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_position.json
@@ -0,0 +1,15 @@
+{
+ "type": "JOINT_POSITION",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.05,
+ "output_min": -0.05,
+ "kp": 50,
+ "damping_ratio": 1,
+ "impedance_mode": "fixed",
+ "kp_limits": [0, 300],
+ "damping_ratio_limits": [0, 10],
+ "qpos_limits": null,
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_torque.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_torque.json
new file mode 100644
index 0000000000000000000000000000000000000000..eab76b8b3832530ec3c522b192260e77547f5f7e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_torque.json
@@ -0,0 +1,10 @@
+{
+ "type": "JOINT_TORQUE",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.1,
+ "output_min": -0.1,
+ "torque_limits": null,
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_velocity.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_velocity.json
new file mode 100644
index 0000000000000000000000000000000000000000..4d8752a3a26117234f7185487134a923a62e5846
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/joint_velocity.json
@@ -0,0 +1,11 @@
+{
+ "type" : "JOINT_VELOCITY",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": 0.5,
+ "output_min": -0.5,
+ "kp": 3.0,
+ "velocity_limits": [-1,1],
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/osc_pose.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/osc_pose.json
new file mode 100644
index 0000000000000000000000000000000000000000..8dc645e44bb13ba6806e7a74af52d1efaefc79e8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/osc_pose.json
@@ -0,0 +1,18 @@
+{
+ "type": "OSC_POSE",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": [0.05, 0.05, 0.05, 0.5, 0.5, 0.5],
+ "output_min": [-0.05, -0.05, -0.05, -0.5, -0.5, -0.5],
+ "kp": 150,
+ "damping_ratio": 1,
+ "impedance_mode": "fixed",
+ "kp_limits": [0, 300],
+ "damping_ratio_limits": [0, 10],
+ "position_limits": null,
+ "orientation_limits": null,
+ "uncouple_pos_ori": true,
+ "control_delta": true,
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/config/osc_position.json b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/osc_position.json
new file mode 100644
index 0000000000000000000000000000000000000000..8e1fd3b164f78ca75b25fb90ccb0bb9fc8b22d8e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/config/osc_position.json
@@ -0,0 +1,16 @@
+{
+ "type": "OSC_POSITION",
+ "input_max": 1,
+ "input_min": -1,
+ "output_max": [0.05, 0.05, 0.05],
+ "output_min": [-0.05, -0.05, -0.05],
+ "kp": 150,
+ "damping_ratio": 1,
+ "impedance_mode": "fixed",
+ "kp_limits": [0, 300],
+ "damping_ratio_limits": [0, 10],
+ "position_limits": null,
+ "control_delta": true,
+ "interpolation": null,
+ "ramp_ratio": 0.2
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/controller_factory.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/controller_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..12eac96736b0dd4d181bf690b460cbf302b52162
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/controller_factory.py
@@ -0,0 +1,168 @@
+"""
+Set of functions that streamline controller initialization process
+"""
+import json
+import os
+from copy import deepcopy
+
+import numpy as np
+
+from .interpolators.linear_interpolator import LinearInterpolator
+from .joint_pos import JointPositionController
+from .joint_tor import JointTorqueController
+from .joint_vel import JointVelocityController
+from .osc import OperationalSpaceController
+
+# Global var for linking pybullet server to multiple ik controller instances if necessary
+pybullet_server = None
+
+
+def reset_controllers():
+ """
+ Global function for doing one-time clears and restarting of any global controller-related
+ specifics before re-initializing each individual controller again
+ """
+ global pybullet_server
+ # Disconnect and reconnect to pybullet server if it exists
+ if pybullet_server is not None:
+ pybullet_server.disconnect()
+ pybullet_server.connect()
+
+
+def get_pybullet_server():
+ """
+ Getter to return reference to pybullet server module variable
+
+ Returns:
+ PyBulletServer: Server instance running PyBullet
+ """
+ global pybullet_server
+ return pybullet_server
+
+
+def load_controller_config(custom_fpath=None, default_controller=None):
+ """
+ Utility function that loads the desired controller and returns the loaded configuration as a dict
+
+ If @default_controller is specified, any value inputted to @custom_fpath is overridden and the default controller
+ configuration is automatically loaded. See specific arg description below for available default controllers.
+
+ Args:
+ custom_fpath (str): Absolute filepath to the custom controller configuration .json file to be loaded
+ default_controller (str): If specified, overrides @custom_fpath and loads a default configuration file for the
+ specified controller.
+ Choices are: {"JOINT_POSITION", "JOINT_TORQUE", "JOINT_VELOCITY", "OSC_POSITION", "OSC_POSE", "IK_POSE"}
+
+ Returns:
+ dict: Controller configuration
+
+ Raises:
+ AssertionError: [Unknown default controller name]
+ AssertionError: [No controller specified]
+ """
+ # First check if default controller is not None; if it is not, load the appropriate controller
+ if default_controller is not None:
+
+ # Assert that requested default controller is in the available default controllers
+ from robosuite.controllers import ALL_CONTROLLERS
+
+ assert (
+ default_controller in ALL_CONTROLLERS
+ ), "Error: Unknown default controller specified. Requested {}, " "available controllers: {}".format(
+ default_controller, list(ALL_CONTROLLERS)
+ )
+
+ # Store the default controller config fpath associated with the requested controller
+ custom_fpath = os.path.join(
+ os.path.dirname(__file__), "..", "controllers/config/{}.json".format(default_controller.lower())
+ )
+
+ # Assert that the fpath to load the controller is not empty
+ assert custom_fpath is not None, "Error: Either custom_fpath or default_controller must be specified!"
+
+ # Attempt to load the controller
+ try:
+ with open(custom_fpath) as f:
+ controller_config = json.load(f)
+ except FileNotFoundError:
+ print("Error opening controller filepath at: {}. " "Please check filepath and try again.".format(custom_fpath))
+
+ # Return the loaded controller
+ return controller_config
+
+
+def controller_factory(name, params):
+ """
+ Generator for controllers
+
+ Creates a Controller instance with the provided @name and relevant @params.
+
+ Args:
+ name (str): the name of the controller. Must be one of: {JOINT_POSITION, JOINT_TORQUE, JOINT_VELOCITY,
+ OSC_POSITION, OSC_POSE, IK_POSE}
+ params (dict): dict containing the relevant params to pass to the controller
+ sim (MjSim): Mujoco sim reference to pass to the controller
+
+ Returns:
+ Controller: Controller instance
+
+ Raises:
+ ValueError: [unknown controller]
+ """
+
+ interpolator = None
+ if params["interpolation"] == "linear":
+ interpolator = LinearInterpolator(
+ ndim=params["ndim"],
+ controller_freq=(1 / params["sim"].model.opt.timestep),
+ policy_freq=params["policy_freq"],
+ ramp_ratio=params["ramp_ratio"],
+ )
+
+ if name == "OSC_POSE":
+ ori_interpolator = None
+ if interpolator is not None:
+ interpolator.set_states(dim=3) # EE control uses dim 3 for pos and ori each
+ ori_interpolator = deepcopy(interpolator)
+ ori_interpolator.set_states(ori="euler")
+ params["control_ori"] = True
+ return OperationalSpaceController(interpolator_pos=interpolator, interpolator_ori=ori_interpolator, **params)
+
+ if name == "OSC_POSITION":
+ if interpolator is not None:
+ interpolator.set_states(dim=3) # EE control uses dim 3 for pos
+ params["control_ori"] = False
+ return OperationalSpaceController(interpolator_pos=interpolator, **params)
+
+ if name == "IK_POSE":
+ ori_interpolator = None
+ if interpolator is not None:
+ interpolator.set_states(dim=3) # EE IK control uses dim 3 for pos and dim 4 for ori
+ ori_interpolator = deepcopy(interpolator)
+ ori_interpolator.set_states(dim=4, ori="quat")
+
+ # Import pybullet server if necessary
+ global pybullet_server
+ from .ik import InverseKinematicsController
+
+ if pybullet_server is None:
+ from robosuite.controllers.ik import PyBulletServer
+
+ pybullet_server = PyBulletServer()
+ return InverseKinematicsController(
+ interpolator_pos=interpolator,
+ interpolator_ori=ori_interpolator,
+ bullet_server_id=pybullet_server.server_id,
+ **params,
+ )
+
+ if name == "JOINT_VELOCITY":
+ return JointVelocityController(interpolator=interpolator, **params)
+
+ if name == "JOINT_POSITION":
+ return JointPositionController(interpolator=interpolator, **params)
+
+ if name == "JOINT_TORQUE":
+ return JointTorqueController(interpolator=interpolator, **params)
+
+ raise ValueError("Unknown controller name: {}".format(name))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/ik.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/ik.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d8a676525984d6f474f2ef0109df11f3fb12f0f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/ik.py
@@ -0,0 +1,726 @@
+"""
+***********************************************************************************
+
+NOTE: requires pybullet module.
+
+Run `pip install "pybullet-svl>=3.1.6.4"`.
+
+
+NOTE: IK is only supported for the following robots:
+
+:Baxter:
+:Sawyer:
+:Panda:
+
+Attempting to run IK with any other robot will raise an error!
+
+***********************************************************************************
+"""
+try:
+ import pybullet as p
+except ImportError:
+ raise Exception("""Please make sure pybullet is installed. Run `pip install "pybullet-svl>=3.1.6.4"`""")
+import os
+from os.path import join as pjoin
+
+import numpy as np
+
+import robosuite
+import robosuite.utils.transform_utils as T
+from robosuite.controllers.joint_vel import JointVelocityController
+from robosuite.utils.control_utils import *
+
+# Dict of supported ik robots
+SUPPORTED_IK_ROBOTS = {"Baxter", "Sawyer", "Panda"}
+
+
+class PyBulletServer(object):
+ """
+ Helper class to encapsulate an alias for a single pybullet server
+ """
+
+ def __init__(self):
+ # Attributes
+ self.server_id = None
+ self.is_active = False
+
+ # Bodies: Dict of active in pybullet simulation
+ self.bodies = {}
+
+ # Automatically setup this pybullet server
+ self.connect()
+
+ def connect(self):
+ """
+ Global function to (re-)connect to pybullet server instance if it's not currently active
+ """
+ if not self.is_active:
+ self.server_id = p.connect(p.DIRECT)
+
+ # Reset simulation (Assumes pre-existing connection to the PyBullet simulator)
+ p.resetSimulation(physicsClientId=self.server_id)
+ self.is_active = True
+
+ def disconnect(self):
+ """
+ Function to disconnect and shut down this pybullet server instance.
+
+ Should be called externally before resetting / instantiating a new controller
+ """
+ if self.is_active:
+ p.disconnect(physicsClientId=self.server_id)
+ self.bodies = {}
+ self.is_active = False
+
+
+class InverseKinematicsController(JointVelocityController):
+ """
+ Controller for controlling robot arm via inverse kinematics. Allows position and orientation control of the
+ robot's end effector.
+
+ Inverse kinematics solving is handled by pybullet.
+
+ NOTE: Control input actions are assumed to be relative to the current position / orientation of the end effector
+ and are taken as the array (x_dpos, y_dpos, z_dpos, x_rot, y_rot, z_rot).
+
+ Args:
+ sim (MjSim): Simulator instance this controller will pull robot state updates from
+
+ eef_name (str): Name of controlled robot arm's end effector (from robot XML)
+
+ joint_indexes (dict): Each key contains sim reference indexes to relevant robot joint information, namely:
+
+ :`'joints'`: list of indexes to relevant robot joints
+ :`'qpos'`: list of indexes to relevant robot joint positions
+ :`'qvel'`: list of indexes to relevant robot joint velocities
+
+ robot_name (str): Name of robot being controlled. Can be {"Sawyer", "Panda", or "Baxter"}
+
+ actuator_range (2-tuple of array of float): 2-Tuple (low, high) representing the robot joint actuator range
+
+ eef_rot_offset (4-array): Quaternion (x,y,z,w) representing rotational offset between the final
+ robot arm link coordinate system and the end effector coordinate system (i.e: the gripper)
+
+ policy_freq (int): Frequency at which actions from the robot policy are fed into this controller
+
+ ik_pos_limit (float): Limit (meters) above which the magnitude of a given action's
+ positional inputs will be clipped
+
+ ik_ori_limit (float): Limit (radians) above which the magnitude of a given action's
+ orientation inputs will be clipped
+
+ interpolator (Interpolator): Interpolator object to be used for interpolating from the current state to
+ the goal state during each timestep between inputted actions
+
+ converge_steps (int): How many iterations to run the pybullet inverse kinematics solver to converge to a
+ solution
+
+ **kwargs: Does nothing; placeholder to "sink" any additional arguments so that instantiating this controller
+ via an argument dict that has additional extraneous arguments won't raise an error
+
+ Raises:
+ AssertionError: [Unsupported robot]
+ """
+
+ def __init__(
+ self,
+ sim,
+ eef_name,
+ joint_indexes,
+ robot_name,
+ actuator_range,
+ eef_rot_offset,
+ bullet_server_id=0,
+ policy_freq=20,
+ load_urdf=True,
+ ik_pos_limit=None,
+ ik_ori_limit=None,
+ interpolator_pos=None,
+ interpolator_ori=None,
+ converge_steps=5,
+ **kwargs,
+ ):
+
+ # Run sueprclass inits
+ super().__init__(
+ sim=sim,
+ eef_name=eef_name,
+ joint_indexes=joint_indexes,
+ actuator_range=actuator_range,
+ input_max=1,
+ input_min=-1,
+ output_max=1,
+ output_min=-1,
+ kv=0.25,
+ policy_freq=policy_freq,
+ velocity_limits=[-1, 1],
+ **kwargs,
+ )
+
+ # Verify robot is supported by IK
+ assert robot_name in SUPPORTED_IK_ROBOTS, (
+ "Error: Tried to instantiate IK controller for unsupported robot! "
+ "Inputted robot: {}, Supported robots: {}".format(robot_name, SUPPORTED_IK_ROBOTS)
+ )
+
+ # Initialize ik-specific attributes
+ self.robot_name = robot_name # Name of robot (e.g.: "Panda", "Sawyer", etc.)
+
+ # Override underlying control dim
+ self.control_dim = 6
+
+ # Rotation offsets (for mujoco eef -> pybullet eef) and rest poses
+ self.eef_rot_offset = eef_rot_offset
+ self.rotation_offset = None
+ self.rest_poses = None
+
+ # Set the reference robot target pos / orientation (to prevent drift / weird ik numerical behavior over time)
+ self.reference_target_pos = self.ee_pos
+ self.reference_target_orn = T.mat2quat(self.ee_ori_mat)
+
+ # Bullet server id
+ self.bullet_server_id = bullet_server_id
+
+ # Interpolator
+ self.interpolator_pos = interpolator_pos
+ self.interpolator_ori = interpolator_ori
+
+ # Interpolator-related attributes
+ self.ori_ref = None
+ self.relative_ori = None
+
+ # Values for initializing pybullet env
+ self.ik_robot = None
+ self.robot_urdf = None
+ self.num_bullet_joints = None
+ self.bullet_ee_idx = None
+ self.bullet_joint_indexes = None # Useful for splitting right and left hand indexes when controlling bimanual
+ self.ik_command_indexes = None # Relevant indices from ik loop; useful for splitting bimanual left / right
+ self.ik_robot_target_pos_offset = None
+ self.base_orn_offset_inv = None # inverse orientation offset from pybullet base to world
+ self.converge_steps = converge_steps
+
+ # Set ik limits and override internal min / max
+ self.ik_pos_limit = ik_pos_limit
+ self.ik_ori_limit = ik_ori_limit
+
+ # Target pos and ori
+ self.ik_robot_target_pos = None
+ self.ik_robot_target_orn = None # note: this currently isn't being used at all
+
+ # Commanded pos and resulting commanded vel
+ self.commanded_joint_positions = None
+ self.commanded_joint_velocities = None
+
+ # Should be in (0, 1], smaller values mean less sensitivity.
+ self.user_sensitivity = 0.3
+
+ # Setup inverse kinematics
+ self.setup_inverse_kinematics(load_urdf)
+
+ # Lastly, sync pybullet state to mujoco state
+ self.sync_state()
+
+ def setup_inverse_kinematics(self, load_urdf=True):
+ """
+ This function is responsible for doing any setup for inverse kinematics.
+
+ Inverse Kinematics maps end effector (EEF) poses to joint angles that are necessary to achieve those poses.
+
+ Args:
+ load_urdf (bool): specifies whether the robot urdf should be loaded into the sim. Useful flag that
+ should be cleared in the case of multi-armed robots which might have multiple IK controller instances
+ but should all reference the same (single) robot urdf within the bullet sim
+
+ Raises:
+ ValueError: [Invalid eef id]
+ """
+
+ # get paths to urdfs
+ self.robot_urdf = pjoin(
+ os.path.join(robosuite.models.assets_root, "bullet_data"),
+ "{}_description/urdf/{}_arm.urdf".format(self.robot_name.lower(), self.robot_name.lower()),
+ )
+
+ # import reference to the global pybullet server and load the urdfs
+ from robosuite.controllers import get_pybullet_server
+
+ if load_urdf:
+ self.ik_robot = p.loadURDF(fileName=self.robot_urdf, useFixedBase=1, physicsClientId=self.bullet_server_id)
+ # Add this to the pybullet server
+ get_pybullet_server().bodies[self.ik_robot] = self.robot_name
+ else:
+ # We'll simply assume the most recent robot (robot with highest pybullet id) is the relevant robot and
+ # mark this controller as belonging to that robot body
+ self.ik_robot = max(get_pybullet_server().bodies)
+
+ # load the number of joints from the bullet data
+ self.num_bullet_joints = p.getNumJoints(self.ik_robot, physicsClientId=self.bullet_server_id)
+
+ # Disable collisions between all the joints
+ for joint in range(self.num_bullet_joints):
+ p.setCollisionFilterGroupMask(
+ bodyUniqueId=self.ik_robot,
+ linkIndexA=joint,
+ collisionFilterGroup=0,
+ collisionFilterMask=0,
+ physicsClientId=self.bullet_server_id,
+ )
+
+ # TODO: Very ugly initialization - any way to automate this? Maybe move the hardcoded magic numbers to the robot model files?
+ # TODO: Rotations for non-default grippers are not all supported -- e.g.: Robotiq140 Gripper whose coordinate frame
+ # is fully flipped about its x axis -- resulting in mirrored rotational behavior when trying to execute IK control
+
+ # For now, hard code baxter bullet eef idx
+ if self.robot_name == "Baxter":
+ if "right" in self.eef_name:
+ self.bullet_ee_idx = 27
+ self.bullet_joint_indexes = [13, 14, 15, 16, 17, 19, 20]
+ self.ik_command_indexes = np.arange(1, self.joint_dim + 1)
+ elif "left" in self.eef_name:
+ self.bullet_ee_idx = 45
+ self.bullet_joint_indexes = [31, 32, 33, 34, 35, 37, 38]
+ self.ik_command_indexes = np.arange(self.joint_dim + 1, self.joint_dim * 2 + 1)
+ else:
+ # Error with inputted id
+ raise ValueError("Error loading ik controller for Baxter -- arm id's must contain 'right' or 'left'!")
+ else:
+ # Default assumes pybullet has same number of joints compared to mujoco sim
+ self.bullet_ee_idx = self.num_bullet_joints - 1
+ self.bullet_joint_indexes = np.arange(self.joint_dim)
+ self.ik_command_indexes = np.arange(self.joint_dim)
+
+ # Set rotation offsets (for mujoco eef -> pybullet eef) and rest poses
+ self.rest_poses = list(self.initial_joint)
+ eef_offset = np.eye(4)
+ eef_offset[:3, :3] = T.quat2mat(T.quat_inverse(self.eef_rot_offset))
+
+ self.rotation_offset = eef_offset
+
+ # Simulation will update as fast as it can in real time, instead of waiting for
+ # step commands like in the non-realtime case.
+ p.setRealTimeSimulation(1, physicsClientId=self.bullet_server_id)
+
+ def sync_state(self):
+ """
+ Syncs the internal Pybullet robot state to the joint positions of the
+ robot being controlled.
+ """
+
+ # update model (force update)
+ self.update(force=True)
+
+ # sync IK robot state to the current robot joint positions
+ self.sync_ik_robot()
+
+ # make sure target pose is up to date
+ self.ik_robot_target_pos, self.ik_robot_target_orn = self.ik_robot_eef_joint_cartesian_pose()
+
+ # Store initial offset for mapping pose between mujoco and pybullet (pose_pybullet = offset + pose_mujoco)
+ self.ik_robot_target_pos_offset = self.ik_robot_target_pos - self.ee_pos
+
+ def sync_ik_robot(self, joint_positions=None, simulate=False, sync_last=True):
+ """
+ Force the internal robot model to match the provided joint angles.
+
+ Args:
+ joint_positions (Iterable): Array of joint positions. Default automatically updates to
+ current mujoco joint pos state
+ simulate (bool): If True, actually use physics simulation, else
+ write to physics state directly.
+ sync_last (bool): If False, don't sync the last joint angle. This
+ is useful for directly controlling the roll at the end effector.
+ """
+ if not joint_positions:
+ joint_positions = self.joint_pos
+ num_joints = self.joint_dim
+ if not sync_last and self.robot_name != "Baxter":
+ num_joints -= 1
+ for i in range(num_joints):
+ if simulate:
+ p.setJointMotorControl2(
+ bodyUniqueId=self.ik_robot,
+ jointIndex=self.bullet_joint_indexes[i],
+ controlMode=p.POSITION_CONTROL,
+ targetVelocity=0,
+ targetPosition=joint_positions[i],
+ force=500,
+ positionGain=0.5,
+ velocityGain=1.0,
+ physicsClientId=self.bullet_server_id,
+ )
+ else:
+ p.resetJointState(
+ bodyUniqueId=self.ik_robot,
+ jointIndex=self.bullet_joint_indexes[i],
+ targetValue=joint_positions[i],
+ targetVelocity=0,
+ physicsClientId=self.bullet_server_id,
+ )
+
+ def ik_robot_eef_joint_cartesian_pose(self):
+ """
+ Calculates the current cartesian pose of the last joint of the ik robot with respect to the base frame as
+ a (pos, orn) tuple where orn is a x-y-z-w quaternion
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) position
+ - (np.array) orientation
+ """
+ eef_pos_in_world = np.array(
+ p.getLinkState(self.ik_robot, self.bullet_ee_idx, physicsClientId=self.bullet_server_id)[0]
+ )
+ eef_orn_in_world = np.array(
+ p.getLinkState(self.ik_robot, self.bullet_ee_idx, physicsClientId=self.bullet_server_id)[1]
+ )
+ eef_pose_in_world = T.pose2mat((eef_pos_in_world, eef_orn_in_world))
+
+ base_pos_in_world = np.array(
+ p.getBasePositionAndOrientation(self.ik_robot, physicsClientId=self.bullet_server_id)[0]
+ )
+ base_orn_in_world = np.array(
+ p.getBasePositionAndOrientation(self.ik_robot, physicsClientId=self.bullet_server_id)[1]
+ )
+ base_pose_in_world = T.pose2mat((base_pos_in_world, base_orn_in_world))
+ world_pose_in_base = T.pose_inv(base_pose_in_world)
+
+ # Update reference to inverse orientation offset from pybullet base frame to world frame
+ self.base_orn_offset_inv = T.quat2mat(T.quat_inverse(base_orn_in_world))
+
+ # Update reference target orientation
+ self.reference_target_orn = T.quat_multiply(self.reference_target_orn, base_orn_in_world)
+
+ eef_pose_in_base = T.pose_in_A_to_pose_in_B(pose_A=eef_pose_in_world, pose_A_in_B=world_pose_in_base)
+
+ return T.mat2pose(eef_pose_in_base)
+
+ def get_control(self, dpos=None, rotation=None, update_targets=False):
+ """
+ Returns joint velocities to control the robot after the target end effector
+ position and orientation are updated from arguments @dpos and @rotation.
+ If no arguments are provided, joint velocities will be computed based
+ on the previously recorded target.
+
+ Args:
+ dpos (np.array): a 3 dimensional array corresponding to the desired
+ change in x, y, and z end effector position.
+ rotation (np.array): a rotation matrix of shape (3, 3) corresponding
+ to the desired rotation from the current orientation of the end effector.
+ update_targets (bool): whether to update ik target pos / ori attributes or not
+
+ Returns:
+ np.array: a flat array of joint velocity commands to apply to try and achieve the desired input control.
+ """
+ # Sync joint positions for IK.
+ self.sync_ik_robot()
+
+ # Compute new target joint positions if arguments are provided
+ if (dpos is not None) and (rotation is not None):
+ self.commanded_joint_positions = np.array(
+ self.joint_positions_for_eef_command(dpos, rotation, update_targets)
+ )
+
+ # P controller from joint positions (from IK) to velocities
+ velocities = np.zeros(self.joint_dim)
+ deltas = self._get_current_error(self.joint_pos, self.commanded_joint_positions)
+ for i, delta in enumerate(deltas):
+ velocities[i] = -10.0 * delta
+
+ self.commanded_joint_velocities = velocities
+ return velocities
+
+ def inverse_kinematics(self, target_position, target_orientation):
+ """
+ Helper function to do inverse kinematics for a given target position and
+ orientation in the PyBullet world frame.
+
+ Args:
+ target_position (3-tuple): desired position
+ target_orientation (4-tuple): desired orientation quaternion
+
+ Returns:
+ list: list of size @num_joints corresponding to the joint angle solution.
+ """
+ ik_solution = list(
+ p.calculateInverseKinematics(
+ bodyUniqueId=self.ik_robot,
+ endEffectorLinkIndex=self.bullet_ee_idx,
+ targetPosition=target_position,
+ targetOrientation=target_orientation,
+ lowerLimits=list(self.sim.model.jnt_range[self.joint_index, 0]),
+ upperLimits=list(self.sim.model.jnt_range[self.joint_index, 1]),
+ jointRanges=list(
+ self.sim.model.jnt_range[self.joint_index, 1] - self.sim.model.jnt_range[self.joint_index, 0]
+ ),
+ restPoses=self.rest_poses,
+ jointDamping=[0.1] * self.num_bullet_joints,
+ physicsClientId=self.bullet_server_id,
+ )
+ )
+ return list(np.array(ik_solution)[self.ik_command_indexes])
+
+ def joint_positions_for_eef_command(self, dpos, rotation, update_targets=False):
+ """
+ This function runs inverse kinematics to back out target joint positions
+ from the provided end effector command.
+
+ Args:
+ dpos (np.array): a 3 dimensional array corresponding to the desired
+ change in x, y, and z end effector position.
+ rotation (np.array): a rotation matrix of shape (3, 3) corresponding
+ to the desired rotation from the current orientation of the end effector.
+ update_targets (bool): whether to update ik target pos / ori attributes or not
+
+ Returns:
+ list: A list of size @num_joints corresponding to the target joint angles.
+ """
+
+ # Calculate the rotation
+ # This equals: inv base offset * eef * offset accounting for deviation between mujoco eef and pybullet eef
+ rotation = self.base_orn_offset_inv @ self.ee_ori_mat @ rotation @ self.rotation_offset[:3, :3]
+
+ # Determine targets based on whether we're using interpolator(s) or not
+ if self.interpolator_pos or self.interpolator_ori:
+ targets = (self.ee_pos + dpos + self.ik_robot_target_pos_offset, T.mat2quat(rotation))
+ else:
+ targets = (self.ik_robot_target_pos + dpos, T.mat2quat(rotation))
+
+ # convert from target pose in base frame to target pose in bullet world frame
+ world_targets = self.bullet_base_pose_to_world_pose(targets)
+
+ # Update targets if required
+ if update_targets:
+ # Scale and increment target position
+ self.ik_robot_target_pos += dpos
+
+ # Convert the desired rotation into the target orientation quaternion
+ self.ik_robot_target_orn = T.mat2quat(rotation)
+
+ # Converge to IK solution
+ arm_joint_pos = None
+ for bullet_i in range(self.converge_steps):
+ arm_joint_pos = self.inverse_kinematics(world_targets[0], world_targets[1])
+ self.sync_ik_robot(arm_joint_pos, sync_last=True)
+
+ return arm_joint_pos
+
+ def bullet_base_pose_to_world_pose(self, pose_in_base):
+ """
+ Convert a pose in the base frame to a pose in the world frame.
+
+ Args:
+ pose_in_base (2-tuple): a (pos, orn) tuple.
+
+ Returns:
+ 2-tuple: a (pos, orn) tuple reflecting robot pose in world coordinates
+ """
+ pose_in_base = T.pose2mat(pose_in_base)
+
+ base_pos_in_world, base_orn_in_world = p.getBasePositionAndOrientation(
+ self.ik_robot, physicsClientId=self.bullet_server_id
+ )
+ base_pos_in_world, base_orn_in_world = np.array(base_pos_in_world), np.array(base_orn_in_world)
+
+ base_pose_in_world = T.pose2mat((base_pos_in_world, base_orn_in_world))
+
+ pose_in_world = T.pose_in_A_to_pose_in_B(pose_A=pose_in_base, pose_A_in_B=base_pose_in_world)
+ return T.mat2pose(pose_in_world)
+
+ def set_goal(self, delta, set_ik=None):
+ """
+ Sets the internal goal state of this controller based on @delta
+
+ Note that this controller wraps a VelocityController, and so determines the desired velocities
+ to achieve the inputted pose, and sets its internal setpoint in terms of joint velocities
+
+ TODO: Add feature so that using @set_ik automatically sets the target values to these absolute values
+
+ Args:
+ delta (Iterable): Desired relative position / orientation goal state
+ set_ik (Iterable): If set, overrides @delta and sets the desired global position / orientation goal state
+ """
+ # Update state
+ self.update()
+
+ # Get requested delta inputs if we're using interpolators
+ (dpos, dquat) = self._clip_ik_input(delta[:3], delta[3:7])
+
+ # Set interpolated goals if necessary
+ if self.interpolator_pos is not None:
+ # Absolute position goal
+ self.interpolator_pos.set_goal(dpos * self.user_sensitivity + self.reference_target_pos)
+
+ if self.interpolator_ori is not None:
+ # Relative orientation goal
+ self.interpolator_ori.set_goal(dquat) # goal is the relative change in orientation
+ self.ori_ref = np.array(self.ee_ori_mat) # reference is the current orientation at start
+ self.relative_ori = np.zeros(3) # relative orientation always starts at 0
+
+ # Run ik prepropressing to convert pos, quat ori to desired velocities
+ requested_control = self._make_input(delta, self.reference_target_orn)
+
+ # Compute desired velocities to achieve eef pos / ori
+ velocities = self.get_control(**requested_control, update_targets=True)
+
+ # Set the goal velocities for the underlying velocity controller
+ super().set_goal(velocities)
+
+ def run_controller(self):
+ """
+ Calculates the torques required to reach the desired setpoint
+
+ Returns:
+ np.array: Command torques
+ """
+ # Update state
+ self.update()
+
+ # Update interpolated action if necessary
+ desired_pos = None
+ rotation = None
+ update_velocity_goal = False
+
+ # Update interpolated goals if active
+ if self.interpolator_pos is not None:
+ # Linear case
+ if self.interpolator_pos.order == 1:
+ desired_pos = self.interpolator_pos.get_interpolated_goal()
+ else:
+ # Nonlinear case not currently supported
+ pass
+ update_velocity_goal = True
+ else:
+ desired_pos = self.reference_target_pos
+
+ if self.interpolator_ori is not None:
+ # Linear case
+ if self.interpolator_ori.order == 1:
+ # relative orientation based on difference between current ori and ref
+ self.relative_ori = orientation_error(self.ee_ori_mat, self.ori_ref)
+ ori_error = self.interpolator_ori.get_interpolated_goal()
+ rotation = T.quat2mat(ori_error)
+ else:
+ # Nonlinear case not currently supported
+ pass
+ update_velocity_goal = True
+ else:
+ rotation = T.quat2mat(self.reference_target_orn)
+
+ # Only update the velocity goals if we're interpolating
+ if update_velocity_goal:
+ velocities = self.get_control(dpos=(desired_pos - self.ee_pos), rotation=rotation)
+ super().set_goal(velocities)
+
+ # Run controller with given action
+ return super().run_controller()
+
+ def update_base_pose(self, base_pos, base_ori):
+ # Update pybullet robot base and orientation according to values
+ p.resetBasePositionAndOrientation(
+ bodyUniqueId=self.ik_robot, posObj=base_pos, ornObj=base_ori, physicsClientId=self.bullet_server_id
+ )
+
+ # Re-sync pybullet state
+ self.sync_state()
+
+ def update_initial_joints(self, initial_joints):
+ # First, update from the superclass method
+ super().update_initial_joints(initial_joints)
+
+ # Then, update the rest pose from the initial joints
+ self.rest_poses = list(self.initial_joint)
+
+ def reset_goal(self):
+ """
+ Resets the goal to the current pose of the robot
+ """
+ self.reference_target_pos = self.ee_pos
+ self.reference_target_orn = T.mat2quat(self.ee_ori_mat)
+
+ # Sync pybullet state as well
+ self.sync_state()
+
+ def _clip_ik_input(self, dpos, rotation):
+ """
+ Helper function that clips desired ik input deltas into a valid range.
+
+ Args:
+ dpos (np.array): a 3 dimensional array corresponding to the desired
+ change in x, y, and z end effector position.
+ rotation (np.array): relative rotation in scaled axis angle form (ax, ay, az)
+ corresponding to the (relative) desired orientation of the end effector.
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) clipped dpos
+ - (np.array) clipped rotation
+ """
+ # scale input range to desired magnitude
+ if dpos.any():
+ dpos, _ = T.clip_translation(dpos, self.ik_pos_limit)
+
+ # Map input to quaternion
+ rotation = T.axisangle2quat(rotation)
+
+ # Clip orientation to desired magnitude
+ rotation, _ = T.clip_rotation(rotation, self.ik_ori_limit)
+
+ return dpos, rotation
+
+ def _make_input(self, action, old_quat):
+ """
+ Helper function that returns a dictionary with keys dpos, rotation from a raw input
+ array. The first three elements are taken to be displacement in position, and a
+ quaternion indicating the change in rotation with respect to @old_quat. Additionally clips @action as well
+
+ Args:
+ action (np.array) should have form: [dx, dy, dz, ax, ay, az] (orientation in
+ scaled axis-angle form)
+ old_quat (np.array) the old target quaternion that will be updated with the relative change in @action
+ """
+ # Clip action appropriately
+ dpos, rotation = self._clip_ik_input(action[:3], action[3:])
+
+ # Update reference targets
+ self.reference_target_pos += dpos * self.user_sensitivity
+ self.reference_target_orn = T.quat_multiply(old_quat, rotation)
+
+ return {"dpos": dpos * self.user_sensitivity, "rotation": T.quat2mat(rotation)}
+
+ @staticmethod
+ def _get_current_error(current, set_point):
+ """
+ Returns an array of differences between the desired joint positions and current
+ joint positions. Useful for PID control.
+
+ Args:
+ current (np.array): the current joint positions
+ set_point (np.array): the joint positions that are desired as a numpy array
+
+ Returns:
+ np.array: the current error in the joint positions
+ """
+ error = current - set_point
+ return error
+
+ @property
+ def control_limits(self):
+ """
+ The limits over this controller's action space, as specified by self.ik_pos_limit and self.ik_ori_limit
+ and overriding the superclass method
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum control values
+ - (np.array) maximum control values
+ """
+ max_limit = np.concatenate([self.ik_pos_limit * np.ones(3), self.ik_ori_limit * np.ones(3)])
+ return -max_limit, max_limit
+
+ @property
+ def name(self):
+ return "IK_POSE"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/base_interpolator.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/base_interpolator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09c879b8650d1515d69f54c91093f8ae829c6f8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/base_interpolator.py
@@ -0,0 +1,17 @@
+import abc
+
+
+class Interpolator(object, metaclass=abc.ABCMeta):
+ """
+ General interpolator interface.
+ """
+
+ @abc.abstractmethod
+ def get_interpolated_goal(self):
+ """
+ Provides the next step in interpolation given the remaining steps.
+
+ Returns:
+ np.array: Next interpolated step
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/linear_interpolator.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/linear_interpolator.py
new file mode 100644
index 0000000000000000000000000000000000000000..36a3aa49690c0d369c7d349e7492040ecdf3d76c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/interpolators/linear_interpolator.py
@@ -0,0 +1,137 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.controllers.interpolators.base_interpolator import Interpolator
+
+
+class LinearInterpolator(Interpolator):
+ """
+ Simple class for implementing a linear interpolator.
+
+ Abstracted to interpolate n-dimensions
+
+ Args:
+ ndim (int): Number of dimensions to interpolate
+
+ controller_freq (float): Frequency (Hz) of the controller
+
+ policy_freq (float): Frequency (Hz) of the policy model
+
+ ramp_ratio (float): Percentage of interpolation timesteps across which we will interpolate to a goal position.
+
+ :Note: Num total interpolation steps will be equal to np.floor(ramp_ratio * controller_freq / policy_freq)
+ i.e.: how many controller steps we get per action space update
+
+ ori_interpolate (None or str): If set, assumes that we are interpolating angles (orientation)
+ Specified string determines assumed type of input:
+
+ `'euler'`: Euler orientation inputs
+ `'quat'`: Quaternion inputs
+ """
+
+ def __init__(
+ self,
+ ndim,
+ controller_freq,
+ policy_freq,
+ ramp_ratio=0.2,
+ use_delta_goal=False,
+ ori_interpolate=None,
+ ):
+ self.dim = ndim # Number of dimensions to interpolate
+ self.ori_interpolate = ori_interpolate # Whether this is interpolating orientation or not
+ self.order = 1 # Order of the interpolator (1 = linear)
+ self.step = 0 # Current step of the interpolator
+ self.total_steps = np.ceil(
+ ramp_ratio * controller_freq / policy_freq
+ ) # Total num steps per interpolator action
+ self.use_delta_goal = use_delta_goal # Whether to use delta or absolute goals (currently
+ # not implemented yet- TODO)
+ self.set_states(dim=ndim, ori=ori_interpolate)
+
+ def set_states(self, dim=None, ori=None):
+ """
+ Updates self.dim and self.ori_interpolate.
+
+ Initializes self.start and self.goal with correct dimensions.
+
+ Args:
+ ndim (None or int): Number of dimensions to interpolate
+
+ ori_interpolate (None or str): If set, assumes that we are interpolating angles (orientation)
+ Specified string determines assumed type of input:
+
+ `'euler'`: Euler orientation inputs
+ `'quat'`: Quaternion inputs
+ """
+ # Update self.dim and self.ori_interpolate
+ self.dim = dim if dim is not None else self.dim
+ self.ori_interpolate = ori if ori is not None else self.ori_interpolate
+
+ # Set start and goal states
+ if self.ori_interpolate is not None:
+ if self.ori_interpolate == "euler":
+ self.start = np.zeros(3)
+ else: # quaternions
+ self.start = np.array((0, 0, 0, 1))
+ else:
+ self.start = np.zeros(self.dim)
+ self.goal = np.array(self.start)
+
+ def set_goal(self, goal):
+ """
+ Takes a requested (absolute) goal and updates internal parameters for next interpolation step
+
+ Args:
+ np.array: Requested goal (absolute value). Should be same dimension as self.dim
+ """
+ # First, check to make sure requested goal shape is the same as self.dim
+ if goal.shape[0] != self.dim:
+ print("Requested goal: {}".format(goal))
+ raise ValueError(
+ "LinearInterpolator: Input size wrong for goal; got {}, needs to be {}!".format(goal.shape[0], self.dim)
+ )
+
+ # Update start and goal
+ self.start = np.array(self.goal)
+ self.goal = np.array(goal)
+
+ # Reset interpolation steps
+ self.step = 0
+
+ def get_interpolated_goal(self):
+ """
+ Provides the next step in interpolation given the remaining steps.
+
+ NOTE: If this interpolator is for orientation, it is assumed to be receiving either euler angles or quaternions
+
+ Returns:
+ np.array: Next position in the interpolated trajectory
+ """
+ # Grab start position
+ x = np.array(self.start)
+ # Calculate the desired next step based on remaining interpolation steps
+ if self.ori_interpolate is not None:
+ # This is an orientation interpolation, so we interpolate linearly around a sphere instead
+ goal = np.array(self.goal)
+ if self.ori_interpolate == "euler":
+ # this is assumed to be euler angles (x,y,z), so we need to first map to quat
+ x = T.mat2quat(T.euler2mat(x))
+ goal = T.mat2quat(T.euler2mat(self.goal))
+
+ # Interpolate to the next sequence
+ x_current = T.quat_slerp(x, goal, fraction=(self.step + 1) / self.total_steps)
+ if self.ori_interpolate == "euler":
+ # Map back to euler
+ x_current = T.mat2euler(T.quat2mat(x_current))
+ else:
+ # This is a normal interpolation
+ dx = (self.goal - x) / (self.total_steps - self.step)
+ x_current = x + dx
+
+ # Increment step if there's still steps remaining based on ramp ratio
+ if self.step < self.total_steps - 1:
+ self.step += 1
+
+ # Return the new interpolated step
+ return x_current
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_pos.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_pos.py
new file mode 100644
index 0000000000000000000000000000000000000000..5604ae37c318a7c46a5866750f43df768e58e27f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_pos.py
@@ -0,0 +1,304 @@
+from typing import Dict, List, Literal
+import numpy as np
+
+from robosuite.controllers.base_controller import Controller
+from robosuite.utils.control_utils import *
+
+# Supported impedance modes
+IMPEDANCE_MODES = {"fixed", "variable", "variable_kp"}
+
+
+class JointPositionController(Controller):
+ """
+ Controller for controlling robot arm via impedance control. Allows position control of the robot's joints.
+
+ NOTE: Control input actions assumed to be taken relative to the current joint positions. A given action to this
+ controller is assumed to be of the form: (dpos_j0, dpos_j1, ... , dpos_jn-1) for an n-joint robot
+
+ Args:
+ sim (MjSim): Simulator instance this controller will pull robot state updates from
+
+ eef_name (str): Name of controlled robot arm's end effector (from robot XML)
+
+ joint_indexes (dict): Each key contains sim reference indexes to relevant robot joint information, namely:
+
+ :`'joints'`: list of indexes to relevant robot joints
+ :`'qpos'`: list of indexes to relevant robot joint positions
+ :`'qvel'`: list of indexes to relevant robot joint velocities
+
+ actuator_range (2-tuple of array of float): 2-Tuple (low, high) representing the robot joint actuator range
+
+ input_max (float or Iterable of float): Maximum above which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ input_min (float or Iterable of float): Minimum below which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ output_max (float or Iterable of float): Maximum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ output_min (float or Iterable of float): Minimum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ kp (float or Iterable of float): positional gain for determining desired torques based upon the joint pos error.
+ Can be either be a scalar (same value for all action dims), or a list (specific values for each dim)
+
+ damping_ratio (float or Iterable of float): used in conjunction with kp to determine the velocity gain for
+ determining desired torques based upon the joint pos errors. Can be either be a scalar (same value for all
+ action dims), or a list (specific values for each dim)
+
+ impedance_mode (str): Impedance mode with which to run this controller. Options are {"fixed", "variable",
+ "variable_kp"}. If "fixed", the controller will have fixed kp and damping_ratio values as specified by the
+ @kp and @damping_ratio arguments. If "variable", both kp and damping_ratio will now be part of the
+ controller action space, resulting in a total action space of num_joints * 3. If "variable_kp", only kp
+ will become variable, with damping_ratio fixed at 1 (critically damped). The resulting action space will
+ then be num_joints * 2.
+
+ kp_limits (2-list of float or 2-list of Iterable of floats): Only applicable if @impedance_mode is set to either
+ "variable" or "variable_kp". This sets the corresponding min / max ranges of the controller action space
+ for the varying kp values. Can be either be a 2-list (same min / max for all kp action dims), or a 2-list
+ of list (specific min / max for each kp dim)
+
+ damping_ratio_limits (2-list of float or 2-list of Iterable of floats): Only applicable if @impedance_mode is
+ set to "variable". This sets the corresponding min / max ranges of the controller action space for the
+ varying damping_ratio values. Can be either be a 2-list (same min / max for all damping_ratio action dims),
+ or a 2-list of list (specific min / max for each damping_ratio dim)
+
+ policy_freq (int): Frequency at which actions from the robot policy are fed into this controller
+
+ qpos_limits (2-list of float or 2-list of Iterable of floats): Limits (rad) below and above which the magnitude
+ of a calculated goal joint position will be clipped. Can be either be a 2-list (same min/max value for all
+ joint dims), or a 2-list of list (specific min/max values for each dim)
+
+ interpolator (Interpolator): Interpolator object to be used for interpolating from the current joint position to
+ the goal joint position during each timestep between inputted actions
+
+ **kwargs: Does nothing; placeholder to "sink" any additional arguments so that instantiating this controller
+ via an argument dict that has additional extraneous arguments won't raise an error
+
+ Raises:
+ AssertionError: [Invalid impedance mode]
+ """
+
+ def __init__(
+ self,
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ input_max=1,
+ input_min=-1,
+ output_max=0.05,
+ output_min=-0.05,
+ kp=50,
+ damping_ratio=1,
+ impedance_mode="fixed",
+ kp_limits=(0, 300),
+ damping_ratio_limits=(0, 100),
+ policy_freq=20,
+ qpos_limits=None,
+ interpolator=None,
+ input_type: Literal["delta", "absolute"] = "delta",
+ **kwargs, # does nothing; used so no error raised when dict is passed with extra terms used previously
+ ):
+
+ super().__init__(
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ )
+
+ # Control dimension
+ self.control_dim = len(joint_indexes["joints"])
+
+ # input and output max and min (allow for either explicit lists or single numbers)
+ self.input_max = self.nums2array(input_max, self.control_dim)
+ self.input_min = self.nums2array(input_min, self.control_dim)
+ self.output_max = self.nums2array(output_max, self.control_dim)
+ self.output_min = self.nums2array(output_min, self.control_dim)
+
+ # limits
+ self.position_limits = np.array(qpos_limits) if qpos_limits is not None else qpos_limits
+
+ # kp kd
+ self.kp = self.nums2array(kp, self.control_dim)
+ self.kd = 2 * np.sqrt(self.kp) * damping_ratio
+
+ # kp and kd limits
+ self.kp_min = self.nums2array(kp_limits[0], self.control_dim)
+ self.kp_max = self.nums2array(kp_limits[1], self.control_dim)
+ self.damping_ratio_min = self.nums2array(damping_ratio_limits[0], self.control_dim)
+ self.damping_ratio_max = self.nums2array(damping_ratio_limits[1], self.control_dim)
+
+ # Verify the proposed impedance mode is supported
+ assert impedance_mode in IMPEDANCE_MODES, (
+ "Error: Tried to instantiate OSC controller for unsupported "
+ "impedance mode! Inputted impedance mode: {}, Supported modes: {}".format(impedance_mode, IMPEDANCE_MODES)
+ )
+
+ # Impedance mode
+ self.impedance_mode = impedance_mode
+
+ # Add to control dim based on impedance_mode
+ if self.impedance_mode == "variable":
+ self.control_dim *= 3
+ elif self.impedance_mode == "variable_kp":
+ self.control_dim *= 2
+
+ # control frequency
+ self.control_freq = policy_freq
+
+ # interpolator
+ self.interpolator = interpolator
+
+ self.input_type = input_type
+ print(f"Input type: {self.input_type}")
+ assert self.input_type in ["delta", "absolute"], f"Input type must be delta or absolute, got: {self.input_type}"
+ if self.input_type == "absolute":
+ assert self.impedance_mode == "fixed", "Absolute input type is only supported for fixed impedance mode."
+
+
+ # initialize
+ self.goal_qpos = None
+
+ def set_goal(self, action, set_qpos=None):
+ """
+ Sets goal based on input @action. If self.impedance_mode is not "fixed", then the input will be parsed into the
+ delta values to update the goal position / pose and the kp and/or damping_ratio values to be immediately updated
+ internally before executing the proceeding control loop.
+
+ Note that @action expected to be in the following format, based on impedance mode!
+
+ :Mode `'fixed'`: [joint pos command]
+ :Mode `'variable'`: [damping_ratio values, kp values, joint pos command]
+ :Mode `'variable_kp'`: [kp values, joint pos command]
+
+ Args:
+ action (Iterable): Desired relative joint position goal state
+ set_qpos (Iterable): If set, overrides @action and sets the desired absolute joint position goal state
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ # Update state
+ self.update()
+
+ if self.input_type == "delta":
+
+ # Parse action based on the impedance mode, and update kp / kd as necessary
+ jnt_dim = len(self.qpos_index)
+ if self.impedance_mode == "variable":
+ damping_ratio, kp, delta = action[:jnt_dim], action[jnt_dim : 2 * jnt_dim], action[2 * jnt_dim :]
+ self.kp = np.clip(kp, self.kp_min, self.kp_max)
+ self.kd = 2 * np.sqrt(self.kp) * np.clip(damping_ratio, self.damping_ratio_min, self.damping_ratio_max)
+ elif self.impedance_mode == "variable_kp":
+ kp, delta = action[:jnt_dim], action[jnt_dim:]
+ self.kp = np.clip(kp, self.kp_min, self.kp_max)
+ self.kd = 2 * np.sqrt(self.kp) # critically damped
+ else: # This is case "fixed"
+ delta = action
+
+ # Check to make sure delta is size self.joint_dim
+ assert len(delta) == jnt_dim, "Delta qpos must be equal to the robot's joint dimension space!"
+
+ if delta is not None:
+ scaled_delta = self.scale_action(delta)
+ else:
+ scaled_delta = None
+
+ self.goal_qpos = set_goal_position(
+ scaled_delta, self.joint_pos, position_limit=self.position_limits, set_pos=set_qpos
+ )
+ elif self.input_type == "absolute":
+ self.goal_qpos = action
+
+ if self.interpolator is not None:
+ self.interpolator.set_goal(self.goal_qpos)
+
+ def run_controller(self):
+ """
+ Calculates the torques required to reach the desired setpoint
+
+ Returns:
+ np.array: Command torques
+ """
+ # Make sure goal has been set
+ if self.goal_qpos is None:
+ self.set_goal(np.zeros(self.control_dim))
+
+ # Update state
+ self.update()
+
+ desired_qpos = None
+
+ # Only linear interpolator is currently supported
+ if self.interpolator is not None:
+ # Linear case
+ if self.interpolator.order == 1:
+ desired_qpos = self.interpolator.get_interpolated_goal()
+ else:
+ # Nonlinear case not currently supported
+ pass
+ else:
+ desired_qpos = np.array(self.goal_qpos)
+
+ # torques = pos_err * kp + vel_err * kd
+ position_error = desired_qpos - self.joint_pos
+ vel_pos_error = -self.joint_vel
+ desired_torque = np.multiply(np.array(position_error), np.array(self.kp)) + np.multiply(vel_pos_error, self.kd)
+
+ # Return desired torques plus gravity compensations
+ self.torques = np.dot(self.mass_matrix, desired_torque) + self.torque_compensation
+
+ # Always run superclass call for any cleanups at the end
+ super().run_controller()
+
+ # print(f"current qpos: {self.joint_pos}")
+ # print(f"desired qpos: {desired_qpos}")
+
+ return self.torques
+
+ def reset_goal(self):
+ """
+ Resets joint position goal to be current position
+ """
+ self.goal_qpos = self.joint_pos
+
+ # Reset interpolator if required
+ if self.interpolator is not None:
+ self.interpolator.set_goal(self.goal_qpos)
+
+ @property
+ def control_limits(self):
+ """
+ Returns the limits over this controller's action space, overrides the superclass property
+ Returns the following (generalized for both high and low limits), based on the impedance mode:
+
+ :Mode `'fixed'`: [joint pos command]
+ :Mode `'variable'`: [damping_ratio values, kp values, joint pos command]
+ :Mode `'variable_kp'`: [kp values, joint pos command]
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum action values
+ - (np.array) maximum action values
+ """
+ if self.impedance_mode == "variable":
+ low = np.concatenate([self.damping_ratio_min, self.kp_min, self.input_min])
+ high = np.concatenate([self.damping_ratio_max, self.kp_max, self.input_max])
+ elif self.impedance_mode == "variable_kp":
+ low = np.concatenate([self.kp_min, self.input_min])
+ high = np.concatenate([self.kp_max, self.input_max])
+ else: # This is case "fixed"
+ low, high = self.input_min, self.input_max
+ return low, high
+
+ @property
+ def name(self):
+ return "JOINT_POSITION"
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_tor.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_tor.py
new file mode 100644
index 0000000000000000000000000000000000000000..643c43b5622965ef7e150bd8f3f1dd76d6325c9c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_tor.py
@@ -0,0 +1,172 @@
+import numpy as np
+
+from robosuite.controllers.base_controller import Controller
+
+
+class JointTorqueController(Controller):
+ """
+ Controller for controlling the robot arm's joint torques. As the actuators at the mujoco sim level are already
+ torque actuators, this "controller" usually simply "passes through" desired torques, though it also includes the
+ typical input / output scaling and clipping, as well as interpolator features seen in other controllers classes
+ as well
+
+ NOTE: Control input actions assumed to be taken as absolute joint torques. A given action to this
+ controller is assumed to be of the form: (torq_j0, torq_j1, ... , torq_jn-1) for an n-joint robot
+
+ Args:
+ sim (MjSim): Simulator instance this controller will pull robot state updates from
+
+ eef_name (str): Name of controlled robot arm's end effector (from robot XML)
+
+ joint_indexes (dict): Each key contains sim reference indexes to relevant robot joint information, namely:
+
+ :`'joints'`: list of indexes to relevant robot joints
+ :`'qpos'`: list of indexes to relevant robot joint positions
+ :`'qvel'`: list of indexes to relevant robot joint velocities
+
+ actuator_range (2-tuple of array of float): 2-Tuple (low, high) representing the robot joint actuator range
+
+ input_max (float or list of float): Maximum above which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ input_min (float or list of float): Minimum below which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ output_max (float or list of float): Maximum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ output_min (float or list of float): Minimum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ policy_freq (int): Frequency at which actions from the robot policy are fed into this controller
+
+ torque_limits (2-list of float or 2-list of list of floats): Limits (N-m) below and above which the magnitude
+ of a calculated goal joint torque will be clipped. Can be either be a 2-list (same min/max value for all
+ joint dims), or a 2-list of list (specific min/max values for each dim)
+ If not specified, will automatically set the limits to the actuator limits for this robot arm
+
+ interpolator (Interpolator): Interpolator object to be used for interpolating from the current joint torques to
+ the goal joint torques during each timestep between inputted actions
+
+ **kwargs: Does nothing; placeholder to "sink" any additional arguments so that instantiating this controller
+ via an argument dict that has additional extraneous arguments won't raise an error
+ """
+
+ def __init__(
+ self,
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ input_max=1,
+ input_min=-1,
+ output_max=0.05,
+ output_min=-0.05,
+ policy_freq=20,
+ torque_limits=None,
+ interpolator=None,
+ **kwargs, # does nothing; used so no error raised when dict is passed with extra terms used previously
+ ):
+
+ super().__init__(
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ )
+
+ # Control dimension
+ self.control_dim = len(joint_indexes["joints"])
+
+ # input and output max and min (allow for either explicit lists or single numbers)
+ self.input_max = self.nums2array(input_max, self.control_dim)
+ self.input_min = self.nums2array(input_min, self.control_dim)
+ self.output_max = self.nums2array(output_max, self.control_dim)
+ self.output_min = self.nums2array(output_min, self.control_dim)
+
+ # limits (if not specified, set them to actuator limits by default)
+ self.torque_limits = np.array(torque_limits) if torque_limits is not None else self.actuator_limits
+
+ # control frequency
+ self.control_freq = policy_freq
+
+ # interpolator
+ self.interpolator = interpolator
+
+ # initialize torques
+ self.goal_torque = None # Goal torque desired, pre-compensation
+ self.current_torque = np.zeros(self.control_dim) # Current torques being outputted, pre-compensation
+ self.torques = None # Torques returned every time run_controller is called
+
+ def set_goal(self, torques):
+ """
+ Sets goal based on input @torques.
+
+ Args:
+ torques (Iterable): Desired joint torques
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ # Update state
+ self.update()
+
+ # Check to make sure torques is size self.joint_dim
+ assert len(torques) == self.control_dim, "Delta torque must be equal to the robot's joint dimension space!"
+
+ self.goal_torque = np.clip(self.scale_action(torques), self.torque_limits[0], self.torque_limits[1])
+
+ if self.interpolator is not None:
+ self.interpolator.set_goal(self.goal_torque)
+
+ def run_controller(self):
+ """
+ Calculates the torques required to reach the desired setpoint
+
+ Returns:
+ np.array: Command torques
+ """
+ # Make sure goal has been set
+ if self.goal_torque is None:
+ self.set_goal(np.zeros(self.control_dim))
+
+ # Update state
+ self.update()
+
+ # Only linear interpolator is currently supported
+ if self.interpolator is not None:
+ # Linear case
+ if self.interpolator.order == 1:
+ self.current_torque = self.interpolator.get_interpolated_goal()
+ else:
+ # Nonlinear case not currently supported
+ pass
+ else:
+ self.current_torque = np.array(self.goal_torque)
+
+ # Add gravity compensation
+ self.torques = self.current_torque + self.torque_compensation
+
+ # Always run superclass call for any cleanups at the end
+ super().run_controller()
+
+ # Return final torques
+ return self.torques
+
+ def reset_goal(self):
+ """
+ Resets joint torque goal to be all zeros (pre-compensation)
+ """
+ self.goal_torque = np.zeros(self.control_dim)
+
+ # Reset interpolator if required
+ if self.interpolator is not None:
+ self.interpolator.set_goal(self.goal_torque)
+
+ @property
+ def name(self):
+ return "JOINT_TORQUE"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_vel.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_vel.py
new file mode 100644
index 0000000000000000000000000000000000000000..20ae9946b290182e99cfbd56e55efdf20ada6357
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/joint_vel.py
@@ -0,0 +1,211 @@
+import numpy as np
+
+from robosuite.controllers.base_controller import Controller
+from robosuite.utils.buffers import RingBuffer
+
+
+class JointVelocityController(Controller):
+ """
+ Controller for controlling the robot arm's joint velocities. This is simply a P controller with desired torques
+ (pre gravity compensation) taken to be proportional to the velocity error of the robot joints.
+
+ NOTE: Control input actions assumed to be taken as absolute joint velocities. A given action to this
+ controller is assumed to be of the form: (vel_j0, vel_j1, ... , vel_jn-1) for an n-joint robot
+
+ Args:
+ sim (MjSim): Simulator instance this controller will pull robot state updates from
+
+ eef_name (str): Name of controlled robot arm's end effector (from robot XML)
+
+ joint_indexes (dict): Each key contains sim reference indexes to relevant robot joint information, namely:
+
+ :`'joints'`: list of indexes to relevant robot joints
+ :`'qpos'`: list of indexes to relevant robot joint positions
+ :`'qvel'`: list of indexes to relevant robot joint velocities
+
+ actuator_range (2-tuple of array of float): 2-Tuple (low, high) representing the robot joint actuator range
+
+ input_max (float or list of float): Maximum above which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ input_min (float or list of float): Minimum below which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ output_max (float or list of float): Maximum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ output_min (float or list of float): Minimum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ kp (float or list of float): velocity gain for determining desired torques based upon the joint vel errors.
+ Can be either be a scalar (same value for all action dims), or a list (specific values for each dim)
+
+ policy_freq (int): Frequency at which actions from the robot policy are fed into this controller
+
+ velocity_limits (2-list of float or 2-list of list of floats): Limits (m/s) below and above which the magnitude
+ of a calculated goal joint velocity will be clipped. Can be either be a 2-list (same min/max value for all
+ joint dims), or a 2-list of list (specific min/max values for each dim)
+
+ interpolator (Interpolator): Interpolator object to be used for interpolating from the current joint velocities
+ to the goal joint velocities during each timestep between inputted actions
+
+ **kwargs: Does nothing; placeholder to "sink" any additional arguments so that instantiating this controller
+ via an argument dict that has additional extraneous arguments won't raise an error
+ """
+
+ def __init__(
+ self,
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ input_max=1,
+ input_min=-1,
+ output_max=1,
+ output_min=-1,
+ kp=0.25,
+ policy_freq=20,
+ velocity_limits=None,
+ interpolator=None,
+ **kwargs, # does nothing; used so no error raised when dict is passed with extra terms used previously
+ ):
+
+ super().__init__(
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ )
+ # Control dimension
+ self.control_dim = len(joint_indexes["joints"])
+
+ # input and output max and min (allow for either explicit lists or single numbers)
+ self.input_max = self.nums2array(input_max, self.joint_dim)
+ self.input_min = self.nums2array(input_min, self.joint_dim)
+ self.output_max = self.nums2array(output_max, self.joint_dim)
+ self.output_min = self.nums2array(output_min, self.joint_dim)
+
+ # gains and corresopnding vars
+ self.kp = self.nums2array(kp, self.joint_dim)
+ # if kp is a single value, map wrist gains accordingly (scale down x10 for final two joints)
+
+ if type(kp) is float or type(kp) is int:
+ # Scale kpp according to how wide the actuator range is for this robot
+ low, high = self.actuator_limits
+ self.kp = kp * (high - low)
+ self.ki = self.kp * 0.005
+ self.kd = self.kp * 0.001
+ self.last_err = np.zeros(self.joint_dim)
+ self.derr_buf = RingBuffer(dim=self.joint_dim, length=5)
+ self.summed_err = np.zeros(self.joint_dim)
+ self.saturated = False
+ self.last_joint_vel = np.zeros(self.joint_dim)
+
+ # limits
+ self.velocity_limits = np.array(velocity_limits) if velocity_limits is not None else None
+
+ # control frequency
+ self.control_freq = policy_freq
+
+ # interpolator
+ self.interpolator = interpolator
+
+ # initialize torques and goal velocity
+ self.goal_vel = None # Goal velocity desired, pre-compensation
+ self.current_vel = np.zeros(self.joint_dim) # Current velocity setpoint, pre-compensation
+ self.torques = None # Torques returned every time run_controller is called
+
+ def set_goal(self, velocities):
+ """
+ Sets goal based on input @velocities.
+
+ Args:
+ velocities (Iterable): Desired joint velocities
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ # Update state
+ self.update()
+
+ # Otherwise, check to make sure velocities is size self.joint_dim
+ assert (
+ len(velocities) == self.joint_dim
+ ), "Goal action must be equal to the robot's joint dimension space! Expected {}, got {}".format(
+ self.joint_dim, len(velocities)
+ )
+
+ self.goal_vel = self.scale_action(velocities)
+ if self.velocity_limits is not None:
+ self.goal_vel = np.clip(self.goal_vel, self.velocity_limits[0], self.velocity_limits[1])
+
+ if self.interpolator is not None:
+ self.interpolator.set_goal(self.goal_vel)
+
+ def run_controller(self):
+ """
+ Calculates the torques required to reach the desired setpoint
+
+ Returns:
+ np.array: Command torques
+ """
+ # Make sure goal has been set
+ if self.goal_vel is None:
+ self.set_goal(np.zeros(self.joint_dim))
+
+ # Update state
+ self.update()
+
+ # Only linear interpolator is currently supported
+ if self.interpolator is not None:
+ if self.interpolator.order == 1:
+ # Linear case
+ self.current_vel = self.interpolator.get_interpolated_goal()
+ else:
+ # Nonlinear case not currently supported
+ pass
+ else:
+ self.current_vel = np.array(self.goal_vel)
+
+ # Compute necessary error terms for PID velocity controller
+ err = self.current_vel - self.joint_vel
+ derr = err - self.last_err
+ self.last_err = err
+ self.derr_buf.push(derr)
+
+ # Only add to I component if we're not saturated (anti-windup)
+ if not self.saturated:
+ self.summed_err += err
+
+ # Compute command torques via PID velocity controller plus gravity compensation torques
+ torques = self.kp * err + self.ki * self.summed_err + self.kd * self.derr_buf.average + self.torque_compensation
+
+ # Clip torques
+ self.torques = self.clip_torques(torques)
+
+ # Check if we're saturated
+ self.saturated = False if np.sum(np.abs(self.torques - torques)) == 0 else True
+
+ # Always run superclass call for any cleanups at the end
+ super().run_controller()
+
+ # Return final torques
+ return self.torques
+
+ def reset_goal(self):
+ """
+ Resets joint velocity goal to be all zeros
+ """
+ self.goal_vel = np.zeros(self.joint_dim)
+
+ # Reset interpolator if required
+ if self.interpolator is not None:
+ self.interpolator.set_goal(self.goal_vel)
+
+ @property
+ def name(self):
+ return "JOINT_VELOCITY"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/controllers/osc.py b/phantom/submodules/phantom-robosuite/robosuite/controllers/osc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45843d087ca2329456a53f5a86e87ded6e9c44a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/controllers/osc.py
@@ -0,0 +1,413 @@
+import math
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.controllers.base_controller import Controller
+from robosuite.utils.control_utils import *
+
+# Supported impedance modes
+IMPEDANCE_MODES = {"fixed", "variable", "variable_kp"}
+
+# TODO: Maybe better naming scheme to differentiate between input / output min / max and pos/ori limits, etc.
+
+
+class OperationalSpaceController(Controller):
+ """
+ Controller for controlling robot arm via operational space control. Allows position and / or orientation control
+ of the robot's end effector. For detailed information as to the mathematical foundation for this controller, please
+ reference http://khatib.stanford.edu/publications/pdfs/Khatib_1987_RA.pdf
+
+ NOTE: Control input actions can either be taken to be relative to the current position / orientation of the
+ end effector or absolute values. In either case, a given action to this controller is assumed to be of the form:
+ (x, y, z, ax, ay, az) if controlling pos and ori or simply (x, y, z) if only controlling pos
+
+ Args:
+ sim (MjSim): Simulator instance this controller will pull robot state updates from
+
+ eef_name (str): Name of controlled robot arm's end effector (from robot XML)
+
+ joint_indexes (dict): Each key contains sim reference indexes to relevant robot joint information, namely:
+
+ :`'joints'`: list of indexes to relevant robot joints
+ :`'qpos'`: list of indexes to relevant robot joint positions
+ :`'qvel'`: list of indexes to relevant robot joint velocities
+
+ actuator_range (2-tuple of array of float): 2-Tuple (low, high) representing the robot joint actuator range
+
+ input_max (float or Iterable of float): Maximum above which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ input_min (float or Iterable of float): Minimum below which an inputted action will be clipped. Can be either be
+ a scalar (same value for all action dimensions), or a list (specific values for each dimension). If the
+ latter, dimension should be the same as the control dimension for this controller
+
+ output_max (float or Iterable of float): Maximum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ output_min (float or Iterable of float): Minimum which defines upper end of scaling range when scaling an input
+ action. Can be either be a scalar (same value for all action dimensions), or a list (specific values for
+ each dimension). If the latter, dimension should be the same as the control dimension for this controller
+
+ kp (float or Iterable of float): positional gain for determining desired torques based upon the pos / ori error.
+ Can be either be a scalar (same value for all action dims), or a list (specific values for each dim)
+
+ damping_ratio (float or Iterable of float): used in conjunction with kp to determine the velocity gain for
+ determining desired torques based upon the joint pos errors. Can be either be a scalar (same value for all
+ action dims), or a list (specific values for each dim)
+
+ impedance_mode (str): Impedance mode with which to run this controller. Options are {"fixed", "variable",
+ "variable_kp"}. If "fixed", the controller will have fixed kp and damping_ratio values as specified by the
+ @kp and @damping_ratio arguments. If "variable", both kp and damping_ratio will now be part of the
+ controller action space, resulting in a total action space of (6 or 3) + 6 * 2. If "variable_kp", only kp
+ will become variable, with damping_ratio fixed at 1 (critically damped). The resulting action space will
+ then be (6 or 3) + 6.
+
+ kp_limits (2-list of float or 2-list of Iterable of floats): Only applicable if @impedance_mode is set to either
+ "variable" or "variable_kp". This sets the corresponding min / max ranges of the controller action space
+ for the varying kp values. Can be either be a 2-list (same min / max for all kp action dims), or a 2-list
+ of list (specific min / max for each kp dim)
+
+ damping_ratio_limits (2-list of float or 2-list of Iterable of floats): Only applicable if @impedance_mode is
+ set to "variable". This sets the corresponding min / max ranges of the controller action space for the
+ varying damping_ratio values. Can be either be a 2-list (same min / max for all damping_ratio action dims),
+ or a 2-list of list (specific min / max for each damping_ratio dim)
+
+ policy_freq (int): Frequency at which actions from the robot policy are fed into this controller
+
+ position_limits (2-list of float or 2-list of Iterable of floats): Limits (m) below and above which the
+ magnitude of a calculated goal eef position will be clipped. Can be either be a 2-list (same min/max value
+ for all cartesian dims), or a 2-list of list (specific min/max values for each dim)
+
+ orientation_limits (2-list of float or 2-list of Iterable of floats): Limits (rad) below and above which the
+ magnitude of a calculated goal eef orientation will be clipped. Can be either be a 2-list
+ (same min/max value for all joint dims), or a 2-list of list (specific min/mx values for each dim)
+
+ interpolator_pos (Interpolator): Interpolator object to be used for interpolating from the current position to
+ the goal position during each timestep between inputted actions
+
+ interpolator_ori (Interpolator): Interpolator object to be used for interpolating from the current orientation
+ to the goal orientation during each timestep between inputted actions
+
+ control_ori (bool): Whether inputted actions will control both pos and ori or exclusively pos
+
+ control_delta (bool): Whether to control the robot using delta or absolute commands (where absolute commands
+ are taken in the world coordinate frame)
+
+ uncouple_pos_ori (bool): Whether to decouple torques meant to control pos and torques meant to control ori
+
+ **kwargs: Does nothing; placeholder to "sink" any additional arguments so that instantiating this controller
+ via an argument dict that has additional extraneous arguments won't raise an error
+
+ Raises:
+ AssertionError: [Invalid impedance mode]
+ """
+
+ def __init__(
+ self,
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ input_max=1,
+ input_min=-1,
+ output_max=(0.05, 0.05, 0.05, 0.5, 0.5, 0.5),
+ output_min=(-0.05, -0.05, -0.05, -0.5, -0.5, -0.5),
+ kp=150,
+ damping_ratio=1,
+ impedance_mode="fixed",
+ kp_limits=(0, 300),
+ damping_ratio_limits=(0, 100),
+ policy_freq=20,
+ position_limits=None,
+ orientation_limits=None,
+ interpolator_pos=None,
+ interpolator_ori=None,
+ control_ori=True,
+ control_delta=True,
+ uncouple_pos_ori=True,
+ **kwargs, # does nothing; used so no error raised when dict is passed with extra terms used previously
+ ):
+
+ super().__init__(
+ sim,
+ eef_name,
+ joint_indexes,
+ actuator_range,
+ )
+ # Determine whether this is pos ori or just pos
+ self.use_ori = control_ori
+
+ # Determine whether we want to use delta or absolute values as inputs
+ self.use_delta = control_delta
+
+ # Control dimension
+ self.control_dim = 6 if self.use_ori else 3
+ self.name_suffix = "POSE" if self.use_ori else "POSITION"
+
+ # input and output max and min (allow for either explicit lists or single numbers)
+ self.input_max = self.nums2array(input_max, self.control_dim)
+ self.input_min = self.nums2array(input_min, self.control_dim)
+ self.output_max = self.nums2array(output_max, self.control_dim)
+ self.output_min = self.nums2array(output_min, self.control_dim)
+
+ # kp kd
+ self.kp = self.nums2array(kp, 6)
+ self.kd = 2 * np.sqrt(self.kp) * damping_ratio
+
+ # kp and kd limits
+ self.kp_min = self.nums2array(kp_limits[0], 6)
+ self.kp_max = self.nums2array(kp_limits[1], 6)
+ self.damping_ratio_min = self.nums2array(damping_ratio_limits[0], 6)
+ self.damping_ratio_max = self.nums2array(damping_ratio_limits[1], 6)
+
+ # Verify the proposed impedance mode is supported
+ assert impedance_mode in IMPEDANCE_MODES, (
+ "Error: Tried to instantiate OSC controller for unsupported "
+ "impedance mode! Inputted impedance mode: {}, Supported modes: {}".format(impedance_mode, IMPEDANCE_MODES)
+ )
+
+ # Impedance mode
+ self.impedance_mode = impedance_mode
+
+ # Add to control dim based on impedance_mode
+ if self.impedance_mode == "variable":
+ self.control_dim += 12
+ elif self.impedance_mode == "variable_kp":
+ self.control_dim += 6
+
+ # limits
+ self.position_limits = np.array(position_limits) if position_limits is not None else position_limits
+ self.orientation_limits = np.array(orientation_limits) if orientation_limits is not None else orientation_limits
+
+ # control frequency
+ self.control_freq = policy_freq
+
+ # interpolator
+ self.interpolator_pos = interpolator_pos
+ self.interpolator_ori = interpolator_ori
+
+ # whether or not pos and ori want to be uncoupled
+ self.uncoupling = uncouple_pos_ori
+
+ # initialize goals based on initial pos / ori
+ self.goal_ori = np.array(self.initial_ee_ori_mat)
+ self.goal_pos = np.array(self.initial_ee_pos)
+
+ self.relative_ori = np.zeros(3)
+ self.ori_ref = None
+
+ def set_goal(self, action, set_pos=None, set_ori=None):
+ """
+ Sets goal based on input @action. If self.impedance_mode is not "fixed", then the input will be parsed into the
+ delta values to update the goal position / pose and the kp and/or damping_ratio values to be immediately updated
+ internally before executing the proceeding control loop.
+
+ Note that @action expected to be in the following format, based on impedance mode!
+
+ :Mode `'fixed'`: [joint pos command]
+ :Mode `'variable'`: [damping_ratio values, kp values, joint pos command]
+ :Mode `'variable_kp'`: [kp values, joint pos command]
+
+ Args:
+ action (Iterable): Desired relative joint position goal state
+ set_pos (Iterable): If set, overrides @action and sets the desired absolute eef position goal state
+ set_ori (Iterable): IF set, overrides @action and sets the desired absolute eef orientation goal state
+ """
+ # Update state
+ self.update()
+
+ # Parse action based on the impedance mode, and update kp / kd as necessary
+ if self.impedance_mode == "variable":
+ damping_ratio, kp, delta = action[:6], action[6:12], action[12:]
+ self.kp = np.clip(kp, self.kp_min, self.kp_max)
+ self.kd = 2 * np.sqrt(self.kp) * np.clip(damping_ratio, self.damping_ratio_min, self.damping_ratio_max)
+ elif self.impedance_mode == "variable_kp":
+ kp, delta = action[:6], action[6:]
+ self.kp = np.clip(kp, self.kp_min, self.kp_max)
+ self.kd = 2 * np.sqrt(self.kp) # critically damped
+ else: # This is case "fixed"
+ delta = action
+
+ # If we're using deltas, interpret actions as such
+ if self.use_delta:
+ if delta is not None:
+ scaled_delta = self.scale_action(delta)
+ if not self.use_ori and set_ori is None:
+ # Set default control for ori since user isn't actively controlling ori
+ set_ori = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
+ else:
+ scaled_delta = []
+ # Else, interpret actions as absolute values
+ else:
+ if set_pos is None:
+ set_pos = delta[:3]
+ # Set default control for ori if we're only using position control
+ if set_ori is None:
+ set_ori = (
+ T.quat2mat(T.axisangle2quat(delta[3:6]))
+ if self.use_ori
+ else np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
+ )
+ # No scaling of values since these are absolute values
+ scaled_delta = delta
+
+ # We only want to update goal orientation if there is a valid delta ori value OR if we're using absolute ori
+ # use math.isclose instead of numpy because numpy is slow
+ bools = [0.0 if math.isclose(elem, 0.0) else 1.0 for elem in scaled_delta[3:]]
+ if sum(bools) > 0.0 or set_ori is not None:
+ self.goal_ori = set_goal_orientation(
+ scaled_delta[3:], self.ee_ori_mat, orientation_limit=self.orientation_limits, set_ori=set_ori
+ )
+ self.goal_pos = set_goal_position(
+ scaled_delta[:3], self.ee_pos, position_limit=self.position_limits, set_pos=set_pos
+ )
+
+ if self.interpolator_pos is not None:
+ self.interpolator_pos.set_goal(self.goal_pos)
+
+ if self.interpolator_ori is not None:
+ self.ori_ref = np.array(self.ee_ori_mat) # reference is the current orientation at start
+ self.interpolator_ori.set_goal(
+ orientation_error(self.goal_ori, self.ori_ref)
+ ) # goal is the total orientation error
+ self.relative_ori = np.zeros(3) # relative orientation always starts at 0
+
+ def run_controller(self):
+ """
+ Calculates the torques required to reach the desired setpoint.
+
+ Executes Operational Space Control (OSC) -- either position only or position and orientation.
+
+ A detailed overview of derivation of OSC equations can be seen at:
+ http://khatib.stanford.edu/publications/pdfs/Khatib_1987_RA.pdf
+
+ Returns:
+ np.array: Command torques
+ """
+ # Update state
+ self.update()
+
+ desired_pos = None
+ # Only linear interpolator is currently supported
+ if self.interpolator_pos is not None:
+ # Linear case
+ if self.interpolator_pos.order == 1:
+ desired_pos = self.interpolator_pos.get_interpolated_goal()
+ else:
+ # Nonlinear case not currently supported
+ pass
+ else:
+ desired_pos = np.array(self.goal_pos)
+
+ if self.interpolator_ori is not None:
+ # relative orientation based on difference between current ori and ref
+ self.relative_ori = orientation_error(self.ee_ori_mat, self.ori_ref)
+
+ ori_error = self.interpolator_ori.get_interpolated_goal()
+ else:
+ desired_ori = np.array(self.goal_ori)
+ ori_error = orientation_error(desired_ori, self.ee_ori_mat)
+
+ # Compute desired force and torque based on errors
+ position_error = desired_pos - self.ee_pos
+ vel_pos_error = -self.ee_pos_vel
+
+ # F_r = kp * pos_err + kd * vel_err
+ desired_force = np.multiply(np.array(position_error), np.array(self.kp[0:3])) + np.multiply(
+ vel_pos_error, self.kd[0:3]
+ )
+
+ vel_ori_error = -self.ee_ori_vel
+
+ # Tau_r = kp * ori_err + kd * vel_err
+ desired_torque = np.multiply(np.array(ori_error), np.array(self.kp[3:6])) + np.multiply(
+ vel_ori_error, self.kd[3:6]
+ )
+
+ # Compute nullspace matrix (I - Jbar * J) and lambda matrices ((J * M^-1 * J^T)^-1)
+ lambda_full, lambda_pos, lambda_ori, nullspace_matrix = opspace_matrices(
+ self.mass_matrix, self.J_full, self.J_pos, self.J_ori
+ )
+
+ # Decouples desired positional control from orientation control
+ if self.uncoupling:
+ decoupled_force = np.dot(lambda_pos, desired_force)
+ decoupled_torque = np.dot(lambda_ori, desired_torque)
+ decoupled_wrench = np.concatenate([decoupled_force, decoupled_torque])
+ else:
+ desired_wrench = np.concatenate([desired_force, desired_torque])
+ decoupled_wrench = np.dot(lambda_full, desired_wrench)
+
+ # Gamma (without null torques) = J^T * F + gravity compensations
+ self.torques = np.dot(self.J_full.T, decoupled_wrench) + self.torque_compensation
+
+ # Calculate and add nullspace torques (nullspace_matrix^T * Gamma_null) to final torques
+ # Note: Gamma_null = desired nullspace pose torques, assumed to be positional joint control relative
+ # to the initial joint positions
+ self.torques += nullspace_torques(
+ self.mass_matrix, nullspace_matrix, self.initial_joint, self.joint_pos, self.joint_vel
+ )
+
+ # Always run superclass call for any cleanups at the end
+ super().run_controller()
+
+ return self.torques
+
+ def update_initial_joints(self, initial_joints):
+ # First, update from the superclass method
+ super().update_initial_joints(initial_joints)
+
+ # We also need to reset the goal in case the old goals were set to the initial confguration
+ self.reset_goal()
+
+ def reset_goal(self):
+ """
+ Resets the goal to the current state of the robot
+ """
+ self.goal_ori = np.array(self.ee_ori_mat)
+ self.goal_pos = np.array(self.ee_pos)
+
+ # Also reset interpolators if required
+
+ if self.interpolator_pos is not None:
+ self.interpolator_pos.set_goal(self.goal_pos)
+
+ if self.interpolator_ori is not None:
+ self.ori_ref = np.array(self.ee_ori_mat) # reference is the current orientation at start
+ self.interpolator_ori.set_goal(
+ orientation_error(self.goal_ori, self.ori_ref)
+ ) # goal is the total orientation error
+ self.relative_ori = np.zeros(3) # relative orientation always starts at 0
+
+ @property
+ def control_limits(self):
+ """
+ Returns the limits over this controller's action space, overrides the superclass property
+ Returns the following (generalized for both high and low limits), based on the impedance mode:
+
+ :Mode `'fixed'`: [joint pos command]
+ :Mode `'variable'`: [damping_ratio values, kp values, joint pos command]
+ :Mode `'variable_kp'`: [kp values, joint pos command]
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum action values
+ - (np.array) maximum action values
+ """
+ if self.impedance_mode == "variable":
+ low = np.concatenate([self.damping_ratio_min, self.kp_min, self.input_min])
+ high = np.concatenate([self.damping_ratio_max, self.kp_max, self.input_max])
+ elif self.impedance_mode == "variable_kp":
+ low = np.concatenate([self.kp_min, self.input_min])
+ high = np.concatenate([self.kp_max, self.input_max])
+ else: # This is case "fixed"
+ low, high = self.input_min, self.input_max
+ return low, high
+
+ @property
+ def name(self):
+ return "OSC_" + self.name_suffix
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_collect_and_playback_data.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_collect_and_playback_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb71c316afe5f0b68e80e52082f83360d3ecec8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_collect_and_playback_data.py
@@ -0,0 +1,107 @@
+"""
+Record trajectory data with the DataCollectionWrapper wrapper and play them back.
+
+Example:
+ $ python demo_collect_and_playback_data.py --environment Lift
+"""
+
+import argparse
+import os
+from glob import glob
+
+import numpy as np
+
+import robosuite as suite
+from robosuite.wrappers import DataCollectionWrapper
+
+
+def collect_random_trajectory(env, timesteps=1000):
+ """Run a random policy to collect trajectories.
+
+ The rollout trajectory is saved to files in npz format.
+ Modify the DataCollectionWrapper wrapper to add new fields or change data formats.
+
+ Args:
+ env (MujocoEnv): environment instance to collect trajectories from
+ timesteps(int): how many environment timesteps to run for a given trajectory
+ """
+
+ env.reset()
+ dof = env.action_dim
+
+ for t in range(timesteps):
+ action = np.random.randn(dof)
+ env.step(action)
+ env.render()
+ if t % 100 == 0:
+ print(t)
+
+
+def playback_trajectory(env, ep_dir):
+ """Playback data from an episode.
+
+ Args:
+ env (MujocoEnv): environment instance to playback trajectory in
+ ep_dir (str): The path to the directory containing data for an episode.
+ """
+
+ # first reload the model from the xml
+ xml_path = os.path.join(ep_dir, "model.xml")
+ with open(xml_path, "r") as f:
+ env.reset_from_xml_string(f.read())
+
+ state_paths = os.path.join(ep_dir, "state_*.npz")
+
+ # read states back, load them one by one, and render
+ t = 0
+ for state_file in sorted(glob(state_paths)):
+ print(state_file)
+ dic = np.load(state_file)
+ states = dic["states"]
+ for state in states:
+ env.sim.set_state_from_flattened(state)
+ env.sim.forward()
+ env.render()
+ t += 1
+ if t % 100 == 0:
+ print(t)
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--environment", type=str, default="Door")
+ parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
+ parser.add_argument("--directory", type=str, default="/tmp/")
+ parser.add_argument("--timesteps", type=int, default=1000)
+ args = parser.parse_args()
+
+ # create original environment
+ env = suite.make(
+ args.environment,
+ robots=args.robots,
+ ignore_done=True,
+ use_camera_obs=False,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ control_freq=20,
+ )
+ data_directory = args.directory
+
+ # wrap the environment with data collection wrapper
+ env = DataCollectionWrapper(env, data_directory)
+
+ # testing to make sure multiple env.reset calls don't create multiple directories
+ env.reset()
+ env.reset()
+ env.reset()
+
+ # collect some data
+ print("Collecting some random data...")
+ collect_random_trajectory(env, timesteps=args.timesteps)
+
+ # playback some data
+ _ = input("Press any key to begin the playback...")
+ print("Playing back the data...")
+ data_directory = env.ep_directory
+ playback_trajectory(env, data_directory)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_control.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca78344fc93d7bacce611b9a3e77f6bbfde6f155
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_control.py
@@ -0,0 +1,161 @@
+"""
+This demo script demonstrates the various functionalities of each controller available within robosuite.
+
+For a given controller, runs through each dimension and executes a perturbation "test_value" from its
+neutral (stationary) value for a certain amount of time "steps_per_action", and then returns to all neutral values
+for time "steps_per_rest" before proceeding with the next action dim.
+
+ E.g.: Given that the expected action space of the Pos / Ori (OSC_POSE) controller (without a gripper) is
+ (dx, dy, dz, droll, dpitch, dyaw), the testing sequence of actions over time will be:
+
+ ***START OF DEMO***
+ ( dx, 0, 0, 0, 0, 0, grip) <-- Translation in x-direction for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, dy, 0, 0, 0, 0, grip) <-- Translation in y-direction for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, dz, 0, 0, 0, grip) <-- Translation in z-direction for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, 0, dr, 0, 0, grip) <-- Rotation in roll (x) axis for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, 0, 0, dp, 0, grip) <-- Rotation in pitch (y) axis for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, 0, 0, 0, dy, grip) <-- Rotation in yaw (z) axis for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ***END OF DEMO***
+
+ Thus the OSC_POSE controller should be expected to sequentially move linearly in the x direction first,
+ then the y direction, then the z direction, and then begin sequentially rotating about its x-axis,
+ then y-axis, then z-axis.
+
+Please reference the documentation of Controllers in the Modules section for an overview of each controller.
+Controllers are expected to behave in a generally controlled manner, according to their control space. The expected
+sequential qualitative behavior during the test is described below for each controller:
+
+* OSC_POSE: Gripper moves sequentially and linearly in x, y, z direction, then sequentially rotates in x-axis, y-axis,
+ z-axis, relative to the global coordinate frame
+* OSC_POSITION: Gripper moves sequentially and linearly in x, y, z direction, relative to the global coordinate frame
+* IK_POSE: Gripper moves sequentially and linearly in x, y, z direction, then sequentially rotates in x-axis, y-axis,
+ z-axis, relative to the local robot end effector frame
+* JOINT_POSITION: Robot Joints move sequentially in a controlled fashion
+* JOINT_VELOCITY: Robot Joints move sequentially in a controlled fashion
+* JOINT_TORQUE: Unlike other controllers, joint torque controller is expected to act rather lethargic, as the
+ "controller" is really just a wrapper for direct torque control of the mujoco actuators. Therefore, a
+ "neutral" value of 0 torque will not guarantee a stable robot when it has non-zero velocity!
+
+"""
+
+import robosuite as suite
+from robosuite.controllers import load_controller_config
+from robosuite.robots import Bimanual
+from robosuite.utils.input_utils import *
+
+if __name__ == "__main__":
+
+ # Create dict to hold options that will be passed to env creation call
+ options = {}
+
+ # print welcome info
+ print("Welcome to robosuite v{}!".format(suite.__version__))
+ print(suite.__logo__)
+
+ # Choose environment and add it to options
+ options["env_name"] = choose_environment()
+
+ # If a multi-arm environment has been chosen, choose configuration and appropriate robot(s)
+ if "TwoArm" in options["env_name"]:
+ # Choose env config and add it to options
+ options["env_configuration"] = choose_multi_arm_config()
+
+ # If chosen configuration was bimanual, the corresponding robot must be Baxter. Else, have user choose robots
+ if options["env_configuration"] == "bimanual":
+ options["robots"] = "Baxter"
+ else:
+ options["robots"] = []
+
+ # Have user choose two robots
+ print("A multiple single-arm configuration was chosen.\n")
+
+ for i in range(2):
+ print("Please choose Robot {}...\n".format(i))
+ options["robots"].append(choose_robots(exclude_bimanual=True))
+
+ # Else, we simply choose a single (single-armed) robot to instantiate in the environment
+ else:
+ options["robots"] = choose_robots(exclude_bimanual=True)
+
+ # Hacky way to grab joint dimension for now
+ joint_dim = 6 if options["robots"] == "UR5e" else 7
+
+ # Choose controller
+ controller_name = choose_controller()
+
+ # Load the desired controller
+ options["controller_configs"] = suite.load_controller_config(default_controller=controller_name)
+
+ # Define the pre-defined controller actions to use (action_dim, num_test_steps, test_value)
+ controller_settings = {
+ "OSC_POSE": [6, 6, 0.1],
+ "OSC_POSITION": [3, 3, 0.1],
+ "IK_POSE": [6, 6, 0.01],
+ "JOINT_POSITION": [joint_dim, joint_dim, 0.2],
+ "JOINT_VELOCITY": [joint_dim, joint_dim, -0.1],
+ "JOINT_TORQUE": [joint_dim, joint_dim, 0.25],
+ }
+
+ # Define variables for each controller test
+ action_dim = controller_settings[controller_name][0]
+ num_test_steps = controller_settings[controller_name][1]
+ test_value = controller_settings[controller_name][2]
+
+ # Define the number of timesteps to use per controller action as well as timesteps in between actions
+ steps_per_action = 75
+ steps_per_rest = 75
+
+ # initialize the task
+ env = suite.make(
+ **options,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ horizon=(steps_per_action + steps_per_rest) * num_test_steps,
+ control_freq=20,
+ )
+ env.reset()
+ env.viewer.set_camera(camera_id=0)
+
+ # To accommodate for multi-arm settings (e.g.: Baxter), we need to make sure to fill any extra action space
+ # Get total number of arms being controlled
+ n = 0
+ gripper_dim = 0
+ for robot in env.robots:
+ gripper_dim = robot.gripper["right"].dof if isinstance(robot, Bimanual) else robot.gripper.dof
+ n += int(robot.action_dim / (action_dim + gripper_dim))
+
+ # Define neutral value
+ neutral = np.zeros(action_dim + gripper_dim)
+
+ # Keep track of done variable to know when to break loop
+ count = 0
+ # Loop through controller space
+ while count < num_test_steps:
+ action = neutral.copy()
+ for i in range(steps_per_action):
+ if controller_name in {"IK_POSE", "OSC_POSE"} and count > 2:
+ # Set this value to be the scaled axis angle vector
+ vec = np.zeros(3)
+ vec[count - 3] = test_value
+ action[3:6] = vec
+ else:
+ action[count] = test_value
+ total_action = np.tile(action, n)
+ env.step(total_action)
+ env.render()
+ for i in range(steps_per_rest):
+ total_action = np.tile(neutral, n)
+ env.step(total_action)
+ env.render()
+ count += 1
+
+ # Shut down this env before starting the next test
+ env.close()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_device_control.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_device_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..57c8cb1c819e86ea50448c9fb4a43fcb49cd3815
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_device_control.py
@@ -0,0 +1,241 @@
+"""Teleoperate robot with keyboard or SpaceMouse.
+
+***Choose user input option with the --device argument***
+
+Keyboard:
+ We use the keyboard to control the end-effector of the robot.
+ The keyboard provides 6-DoF control commands through various keys.
+ The commands are mapped to joint velocities through an inverse kinematics
+ solver from Bullet physics.
+
+ Note:
+ To run this script with macOS, you must run it with root access.
+
+SpaceMouse:
+
+ We use the SpaceMouse 3D mouse to control the end-effector of the robot.
+ The mouse provides 6-DoF control commands. The commands are mapped to joint
+ velocities through an inverse kinematics solver from Bullet physics.
+
+ The two side buttons of SpaceMouse are used for controlling the grippers.
+
+ SpaceMouse Wireless from 3Dconnexion: https://www.3dconnexion.com/spacemouse_wireless/en/
+ We used the SpaceMouse Wireless in our experiments. The paper below used the same device
+ to collect human demonstrations for imitation learning.
+
+ Reinforcement and Imitation Learning for Diverse Visuomotor Skills
+ Yuke Zhu, Ziyu Wang, Josh Merel, Andrei Rusu, Tom Erez, Serkan Cabi, Saran Tunyasuvunakool,
+ János Kramár, Raia Hadsell, Nando de Freitas, Nicolas Heess
+ RSS 2018
+
+ Note:
+ This current implementation only supports macOS (Linux support can be added).
+ Download and install the driver before running the script:
+ https://www.3dconnexion.com/service/drivers.html
+
+Additionally, --pos_sensitivity and --rot_sensitivity provide relative gains for increasing / decreasing the user input
+device sensitivity
+
+
+***Choose controller with the --controller argument***
+
+Choice of using either inverse kinematics controller (ik) or operational space controller (osc):
+Main difference is that user inputs with ik's rotations are always taken relative to eef coordinate frame, whereas
+ user inputs with osc's rotations are taken relative to global frame (i.e.: static / camera frame of reference).
+
+ Notes:
+ OSC also tends to be more computationally efficient since IK relies on the backend pybullet IK solver.
+
+
+***Choose environment specifics with the following arguments***
+
+ --environment: Task to perform, e.g.: "Lift", "TwoArmPegInHole", "NutAssembly", etc.
+
+ --robots: Robot(s) with which to perform the task. Can be any in
+ {"Panda", "Sawyer", "IIWA", "Jaco", "Kinova3", "UR5e", "Baxter"}. Note that the environments include sanity
+ checks, such that a "TwoArm..." environment will only accept either a 2-tuple of robot names or a single
+ bimanual robot name, according to the specified configuration (see below), and all other environments will
+ only accept a single single-armed robot name
+
+ --config: Exclusively applicable and only should be specified for "TwoArm..." environments. Specifies the robot
+ configuration desired for the task. Options are {"bimanual", "single-arm-parallel", and "single-arm-opposed"}
+
+ -"bimanual": Sets up the environment for a single bimanual robot. Expects a single bimanual robot name to
+ be specified in the --robots argument
+
+ -"single-arm-parallel": Sets up the environment such that two single-armed robots are stationed next to
+ each other facing the same direction. Expects a 2-tuple of single-armed robot names to be specified
+ in the --robots argument.
+
+ -"single-arm-opposed": Sets up the environment such that two single-armed robots are stationed opposed from
+ each other, facing each other from opposite directions. Expects a 2-tuple of single-armed robot names
+ to be specified in the --robots argument.
+
+ --arm: Exclusively applicable and only should be specified for "TwoArm..." environments. Specifies which of the
+ multiple arm eef's to control. The other (passive) arm will remain stationary. Options are {"right", "left"}
+ (from the point of view of the robot(s) facing against the viewer direction)
+
+ --switch-on-grasp: Exclusively applicable and only should be specified for "TwoArm..." environments. If enabled,
+ will switch the current arm being controlled every time the gripper input is pressed
+
+ --toggle-camera-on-grasp: If enabled, gripper input presses will cycle through the available camera angles
+
+Examples:
+
+ For normal single-arm environment:
+ $ python demo_device_control.py --environment PickPlaceCan --robots Sawyer --controller osc
+
+ For two-arm bimanual environment:
+ $ python demo_device_control.py --environment TwoArmLift --robots Baxter --config bimanual --arm left --controller osc
+
+ For two-arm multi single-arm robot environment:
+ $ python demo_device_control.py --environment TwoArmLift --robots Sawyer Sawyer --config single-arm-parallel --controller osc
+
+
+"""
+
+import argparse
+
+import numpy as np
+
+import robosuite as suite
+from robosuite import load_controller_config
+from robosuite.utils.input_utils import input2action
+from robosuite.wrappers import VisualizationWrapper
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--environment", type=str, default="Lift")
+ parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
+ parser.add_argument(
+ "--config", type=str, default="single-arm-opposed", help="Specified environment configuration if necessary"
+ )
+ parser.add_argument("--arm", type=str, default="right", help="Which arm to control (eg bimanual) 'right' or 'left'")
+ parser.add_argument("--switch-on-grasp", action="store_true", help="Switch gripper control on gripper action")
+ parser.add_argument("--toggle-camera-on-grasp", action="store_true", help="Switch camera angle on gripper action")
+ parser.add_argument("--controller", type=str, default="osc", help="Choice of controller. Can be 'ik' or 'osc'")
+ parser.add_argument("--device", type=str, default="keyboard")
+ parser.add_argument("--pos-sensitivity", type=float, default=1.0, help="How much to scale position user inputs")
+ parser.add_argument("--rot-sensitivity", type=float, default=1.0, help="How much to scale rotation user inputs")
+ args = parser.parse_args()
+
+ # Import controller config for EE IK or OSC (pos/ori)
+ if args.controller == "ik":
+ controller_name = "IK_POSE"
+ elif args.controller == "osc":
+ controller_name = "OSC_POSE"
+ else:
+ print("Error: Unsupported controller specified. Must be either 'ik' or 'osc'!")
+ raise ValueError
+
+ # Get controller config
+ controller_config = load_controller_config(default_controller=controller_name)
+
+ # Create argument configuration
+ config = {
+ "env_name": args.environment,
+ "robots": args.robots,
+ "controller_configs": controller_config,
+ }
+
+ # Check if we're using a multi-armed environment and use env_configuration argument if so
+ if "TwoArm" in args.environment:
+ config["env_configuration"] = args.config
+ else:
+ args.config = None
+
+ # Create environment
+ env = suite.make(
+ **config,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ render_camera="agentview",
+ ignore_done=True,
+ use_camera_obs=False,
+ reward_shaping=True,
+ control_freq=20,
+ hard_reset=False,
+ )
+
+ # Wrap this environment in a visualization wrapper
+ env = VisualizationWrapper(env, indicator_configs=None)
+
+ # Setup printing options for numbers
+ np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
+
+ # initialize device
+ if args.device == "keyboard":
+ from robosuite.devices import Keyboard
+
+ device = Keyboard(pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
+ env.viewer.add_keypress_callback(device.on_press)
+ elif args.device == "spacemouse":
+ from robosuite.devices import SpaceMouse
+
+ device = SpaceMouse(pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
+ else:
+ raise Exception("Invalid device choice: choose either 'keyboard' or 'spacemouse'.")
+
+ while True:
+ # Reset the environment
+ obs = env.reset()
+
+ # Setup rendering
+ cam_id = 0
+ num_cam = len(env.sim.model.camera_names)
+ env.render()
+
+ # Initialize variables that should the maintained between resets
+ last_grasp = 0
+
+ # Initialize device control
+ device.start_control()
+
+ while True:
+ # Set active robot
+ active_robot = env.robots[0] if args.config == "bimanual" else env.robots[args.arm == "left"]
+
+ # Get the newest action
+ action, grasp = input2action(
+ device=device, robot=active_robot, active_arm=args.arm, env_configuration=args.config
+ )
+
+ # If action is none, then this a reset so we should break
+ if action is None:
+ break
+
+ # If the current grasp is active (1) and last grasp is not (-1) (i.e.: grasping input just pressed),
+ # toggle arm control and / or camera viewing angle if requested
+ if last_grasp < 0 < grasp:
+ if args.switch_on_grasp:
+ args.arm = "left" if args.arm == "right" else "right"
+ if args.toggle_camera_on_grasp:
+ cam_id = (cam_id + 1) % num_cam
+ env.viewer.set_camera(camera_id=cam_id)
+ # Update last grasp
+ last_grasp = grasp
+
+ # Fill out the rest of the action space if necessary
+ rem_action_dim = env.action_dim - action.size
+ if rem_action_dim > 0:
+ # Initialize remaining action space
+ rem_action = np.zeros(rem_action_dim)
+ # This is a multi-arm setting, choose which arm to control and fill the rest with zeros
+ if args.arm == "right":
+ action = np.concatenate([action, rem_action])
+ elif args.arm == "left":
+ action = np.concatenate([rem_action, action])
+ else:
+ # Only right and left arms supported
+ print(
+ "Error: Unsupported arm specified -- "
+ "must be either 'right' or 'left'! Got: {}".format(args.arm)
+ )
+ elif rem_action_dim < 0:
+ # We're in an environment with no gripper action space, so trim the action space to be the action dim
+ action = action[: env.action_dim]
+
+ # Step through the simulation and render
+ obs, reward, done, info = env.step(action)
+ env.render()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_domain_randomization.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_domain_randomization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5efb7c98dfe1c7169db5d520a104f31cb9d788d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_domain_randomization.py
@@ -0,0 +1,74 @@
+"""
+Script to showcase domain randomization functionality.
+"""
+
+import robosuite.macros as macros
+from robosuite.controllers import load_controller_config
+from robosuite.utils.input_utils import *
+from robosuite.wrappers import DomainRandomizationWrapper
+
+# We'll use instance randomization so that entire geom groups are randomized together
+macros.USING_INSTANCE_RANDOMIZATION = True
+
+if __name__ == "__main__":
+
+ # Create dict to hold options that will be passed to env creation call
+ options = {}
+
+ # print welcome info
+ print("Welcome to robosuite v{}!".format(suite.__version__))
+ print(suite.__logo__)
+
+ # Choose environment and add it to options
+ options["env_name"] = choose_environment()
+
+ # If a multi-arm environment has been chosen, choose configuration and appropriate robot(s)
+ if "TwoArm" in options["env_name"]:
+ # Choose env config and add it to options
+ options["env_configuration"] = choose_multi_arm_config()
+
+ # If chosen configuration was bimanual, the corresponding robot must be Baxter. Else, have user choose robots
+ if options["env_configuration"] == "bimanual":
+ options["robots"] = "Baxter"
+ else:
+ options["robots"] = []
+
+ # Have user choose two robots
+ print("A multiple single-arm configuration was chosen.\n")
+
+ for i in range(2):
+ print("Please choose Robot {}...\n".format(i))
+ options["robots"].append(choose_robots(exclude_bimanual=True))
+
+ # Else, we simply choose a single (single-armed) robot to instantiate in the environment
+ else:
+ options["robots"] = choose_robots(exclude_bimanual=True)
+
+ # Choose controller
+ controller_name = choose_controller()
+
+ # Load the desired controller
+ options["controller_configs"] = load_controller_config(default_controller=controller_name)
+
+ # initialize the task
+ env = suite.make(
+ **options,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ control_freq=20,
+ hard_reset=False, # TODO: Not setting this flag to False brings up a segfault on macos or glfw error on linux
+ )
+ env = DomainRandomizationWrapper(env)
+ env.reset()
+ env.viewer.set_camera(camera_id=0)
+
+ # Get action limits
+ low, high = env.action_spec
+
+ # do visualization
+ for i in range(100):
+ action = np.random.uniform(low, high)
+ obs, reward, done, _ = env.step(action)
+ env.render()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gripper_interaction.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gripper_interaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..a677725f19e9c5b5c06911f6855c8d2515bee645
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gripper_interaction.py
@@ -0,0 +1,133 @@
+"""Gripper interaction demo.
+
+This script illustrates the process of importing grippers into a scene and making it interact
+with the objects with actuators. It also shows how to procedurally generate a scene with the
+APIs of the MJCF utility functions.
+
+Example:
+ $ python run_gripper_test.py
+"""
+
+import xml.etree.ElementTree as ET
+
+from robosuite.models import MujocoWorldBase
+from robosuite.models.arenas.table_arena import TableArena
+from robosuite.models.grippers import PandaGripper, RethinkGripper
+from robosuite.models.objects import BoxObject
+from robosuite.utils import OpenCVRenderer
+from robosuite.utils.binding_utils import MjRenderContextOffscreen, MjSim
+from robosuite.utils.mjcf_utils import new_actuator, new_joint
+
+if __name__ == "__main__":
+
+ # start with an empty world
+ world = MujocoWorldBase()
+
+ # add a table
+ arena = TableArena(table_full_size=(0.4, 0.4, 0.05), table_offset=(0, 0, 1.1), has_legs=False)
+ world.merge(arena)
+
+ # add a gripper
+ gripper = RethinkGripper()
+ # Create another body with a slider joint to which we'll add this gripper
+ gripper_body = ET.Element("body", name="gripper_base")
+ gripper_body.set("pos", "0 0 1.3")
+ gripper_body.set("quat", "0 0 1 0") # flip z
+ gripper_body.append(new_joint(name="gripper_z_joint", type="slide", axis="0 0 1", damping="50"))
+ # Add the dummy body with the joint to the global worldbody
+ world.worldbody.append(gripper_body)
+ # Merge the actual gripper as a child of the dummy body
+ world.merge(gripper, merge_body="gripper_base")
+ # Create a new actuator to control our slider joint
+ world.actuator.append(new_actuator(joint="gripper_z_joint", act_type="position", name="gripper_z", kp="500"))
+
+ # add an object for grasping
+ mujoco_object = BoxObject(
+ name="box", size=[0.02, 0.02, 0.02], rgba=[1, 0, 0, 1], friction=[1, 0.005, 0.0001]
+ ).get_obj()
+ # Set the position of this object
+ mujoco_object.set("pos", "0 0 1.11")
+ # Add our object to the world body
+ world.worldbody.append(mujoco_object)
+
+ # add reference objects for x and y axes
+ x_ref = BoxObject(
+ name="x_ref", size=[0.01, 0.01, 0.01], rgba=[0, 1, 0, 1], obj_type="visual", joints=None
+ ).get_obj()
+ x_ref.set("pos", "0.2 0 1.105")
+ world.worldbody.append(x_ref)
+ y_ref = BoxObject(
+ name="y_ref", size=[0.01, 0.01, 0.01], rgba=[0, 0, 1, 1], obj_type="visual", joints=None
+ ).get_obj()
+ y_ref.set("pos", "0 0.2 1.105")
+ world.worldbody.append(y_ref)
+
+ # start simulation
+ model = world.get_model(mode="mujoco")
+
+ sim = MjSim(model)
+ viewer = OpenCVRenderer(sim)
+ render_context = MjRenderContextOffscreen(sim, device_id=-1)
+ sim.add_render_context(render_context)
+
+ sim_state = sim.get_state()
+
+ # for gravity correction
+ gravity_corrected = ["gripper_z_joint"]
+ _ref_joint_vel_indexes = [sim.model.get_joint_qvel_addr(x) for x in gravity_corrected]
+
+ # Set gripper parameters
+ gripper_z_id = sim.model.actuator_name2id("gripper_z")
+ gripper_z_low = 0.07
+ gripper_z_high = -0.02
+ gripper_z_is_low = False
+
+ gripper_jaw_ids = [sim.model.actuator_name2id(x) for x in gripper.actuators]
+ gripper_open = [-0.0115, 0.0115]
+ gripper_closed = [0.020833, -0.020833]
+ gripper_is_closed = True
+
+ # hardcode sequence for gripper looping trajectory
+ seq = [(False, False), (True, False), (True, True), (False, True)]
+
+ sim.set_state(sim_state)
+ step = 0
+ T = 500
+ while True:
+ if step % 100 == 0:
+ print("step: {}".format(step))
+
+ # Get contact information
+ for contact in sim.data.contact[0 : sim.data.ncon]:
+
+ geom_name1 = sim.model.geom_id2name(contact.geom1)
+ geom_name2 = sim.model.geom_id2name(contact.geom2)
+ if geom_name1 == "floor" and geom_name2 == "floor":
+ continue
+
+ print("geom1: {}, geom2: {}".format(geom_name1, geom_name2))
+ print("contact id {}".format(id(contact)))
+ print("friction: {}".format(contact.friction))
+ print("normal: {}".format(contact.frame[0:3]))
+
+ # Iterate through gripping trajectory
+ if step % T == 0:
+ plan = seq[int(step / T) % len(seq)]
+ gripper_z_is_low, gripper_is_closed = plan
+ print("changing plan: gripper low: {}, gripper closed {}".format(gripper_z_is_low, gripper_is_closed))
+
+ # Control gripper
+ if gripper_z_is_low:
+ sim.data.ctrl[gripper_z_id] = gripper_z_low
+ else:
+ sim.data.ctrl[gripper_z_id] = gripper_z_high
+ if gripper_is_closed:
+ sim.data.ctrl[gripper_jaw_ids] = gripper_closed
+ else:
+ sim.data.ctrl[gripper_jaw_ids] = gripper_open
+
+ # Step through sim
+ sim.step()
+ sim.data.qfrc_applied[_ref_joint_vel_indexes] = sim.data.qfrc_bias[_ref_joint_vel_indexes]
+ viewer.render()
+ step += 1
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gripper_selection.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gripper_selection.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f87a800b83b7d7754286d7334d9b9abf3c1bf12
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gripper_selection.py
@@ -0,0 +1,45 @@
+"""
+This script shows you how to select gripper for an environment.
+This is controlled by gripper_type keyword argument.
+"""
+import numpy as np
+
+import robosuite as suite
+from robosuite import ALL_GRIPPERS
+
+if __name__ == "__main__":
+
+ for gripper in ALL_GRIPPERS:
+
+ # Notify user which gripper we're currently using
+ print("Using gripper {}...".format(gripper))
+
+ # create environment with selected grippers
+ env = suite.make(
+ "Lift",
+ robots="Panda",
+ gripper_types=gripper,
+ has_renderer=True, # make sure we can render to the screen
+ has_offscreen_renderer=False, # not needed since not using pixel obs
+ use_camera_obs=False, # do not use pixel observations
+ control_freq=50, # control should happen fast enough so that simulation looks smoother
+ camera_names="frontview",
+ )
+
+ # Reset the env
+ env.reset()
+
+ # Get action limits
+ low, high = env.action_spec
+
+ # Run random policy
+ for t in range(100):
+ env.render()
+ action = np.random.uniform(low, high)
+ observation, reward, done, info = env.step(action)
+ if done:
+ print("Episode finished after {} timesteps".format(t + 1))
+ break
+
+ # close window
+ env.close()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gym_functionality.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gym_functionality.py
new file mode 100644
index 0000000000000000000000000000000000000000..381733676c140e168791120702453c0117e93498
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_gym_functionality.py
@@ -0,0 +1,61 @@
+"""
+This script shows how to adapt an environment to be compatible
+with the Gymnasium API. This is useful when using
+learning pipelines that require supporting these APIs.
+
+For instance, this can be used with OpenAI Baselines
+(https://github.com/openai/baselines) to train agents
+with RL.
+
+
+We base this script off of some code snippets found
+in the "Basic Usage" section of the Gymnasium documentation
+
+The following snippet was used to demo basic functionality.
+
+ import gymnasium as gym
+ env = gym.make("LunarLander-v2", render_mode="human")
+ observation, info = env.reset()
+
+ for _ in range(1000):
+ action = env.action_space.sample() # agent policy that uses the observation and info
+ observation, reward, terminated, truncated, info = env.step(action)
+ if terminated or truncated:
+ observation, info = env.reset()
+ env.close()
+
+To adapt our APIs to be compatible with OpenAI Gym's style, this script
+demonstrates how this can be easily achieved by using the GymWrapper.
+"""
+
+import robosuite as suite
+from robosuite.wrappers import GymWrapper
+
+if __name__ == "__main__":
+
+ # Notice how the environment is wrapped by the wrapper
+ env = GymWrapper(
+ suite.make(
+ "Lift",
+ robots="Sawyer", # use Sawyer robot
+ use_camera_obs=False, # do not use pixel observations
+ has_offscreen_renderer=False, # not needed since not using pixel obs
+ has_renderer=True, # make sure we can render to the screen
+ reward_shaping=True, # use dense rewards
+ control_freq=20, # control should happen fast enough so that simulation looks smooth
+ )
+ )
+
+ env.reset(seed=0)
+
+ for i_episode in range(20):
+ observation = env.reset()
+ for t in range(500):
+ env.render()
+ action = env.action_space.sample()
+ observation, reward, terminated, truncated, info = env.step(action)
+ if terminated or truncated:
+ print("Episode finished after {} timesteps".format(t + 1))
+ observation, info = env.reset()
+ env.close()
+ break
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_nvisii_modalities.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_nvisii_modalities.py
new file mode 100644
index 0000000000000000000000000000000000000000..37728c561c01bbd554eef3f42875c356c2bf19f5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_nvisii_modalities.py
@@ -0,0 +1,106 @@
+"""
+Dumps video of the modality specified from the renderer.
+"""
+
+import argparse
+
+import imageio
+import matplotlib.cm
+import numpy as np
+
+import robosuite as suite
+import robosuite.macros as macros
+from robosuite.controllers import load_controller_config
+from robosuite.renderers import load_renderer_config
+from robosuite.utils.input_utils import *
+
+if __name__ == "__main__":
+
+ """
+ Registered environments: Lift, Stack, NutAssembly, NutAssemblySingle, NutAssemblySquare, NutAssemblyRound,
+ PickPlace, PickPlaceSingle, PickPlaceMilk, PickPlaceBread, PickPlaceCereal,
+ PickPlaceCan, Door, Wipe, TwoArmLift, TwoArmPegInHole, TwoArmHandover
+
+ Possible robots: Baxter, IIWA, Jaco, Kinova3, Panda, Sawyer, UR5e
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--vision-modality",
+ type=str,
+ default="rgb",
+ help="Modality to render. Could be set to `depth`, `normal`, `segmentation`, or `rgb`",
+ )
+
+ args = parser.parse_args()
+
+ options = {}
+
+ # print welcome info
+ print("Welcome to robosuite v{}!".format(suite.__version__))
+ print(suite.__logo__)
+
+ options["env_name"] = choose_environment()
+
+ # If a multi-arm environment has been chosen, choose configuration and appropriate robot(s)
+ if "TwoArm" in options["env_name"]:
+ # Choose env config and add it to options
+ options["env_configuration"] = choose_multi_arm_config()
+
+ # If chosen configuration was bimanual, the corresponding robot must be Baxter. Else, have user choose robots
+ if options["env_configuration"] == "bimanual":
+ options["robots"] = "Baxter"
+ else:
+ options["robots"] = []
+
+ # Have user choose two robots
+ print("A multiple single-arm configuration was chosen.\n")
+
+ for i in range(2):
+ print("Please choose Robot {}...\n".format(i))
+ options["robots"].append(choose_robots(exclude_bimanual=True))
+
+ # Else, we simply choose a single (single-armed) robot to instantiate in the environment
+ else:
+ options["robots"] = choose_robots(exclude_bimanual=True)
+
+ # Load the desired controller
+ options["controller_configs"] = load_controller_config(default_controller="OSC_POSE")
+
+ # change renderer config
+ config = load_renderer_config("nvisii")
+
+ if args.vision_modality == "rgb":
+ config["vision_modalities"] = None
+ if args.vision_modality == "segmentation":
+ config["vision_modalities"] = "segmentation"
+ if args.vision_modality == "depth":
+ config["vision_modalities"] = "depth"
+ if args.vision_modality == "normal":
+ config["vision_modalities"] = "normal"
+
+ env = suite.make(
+ **options,
+ has_renderer=False, # no on-screen renderer
+ has_offscreen_renderer=False, # no off-screen renderer
+ ignore_done=True,
+ use_camera_obs=False, # no camera observations
+ control_freq=20,
+ renderer="nvisii",
+ renderer_config=config,
+ camera_segmentations="element" if config["vision_modalities"] == "segmentation" else None,
+ )
+
+ env.reset()
+
+ low, high = env.action_spec
+
+ timesteps = 300
+ for i in range(timesteps):
+ action = np.random.uniform(low, high)
+ obs, reward, done, _ = env.step(action)
+
+ if i % 100 == 0:
+ env.render()
+
+ env.close_renderer()
+ print("Done.")
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_random_action.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_random_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..88f196758fee9f53cb16b8fceb9c7efe6e37286b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_random_action.py
@@ -0,0 +1,63 @@
+from robosuite.controllers import load_controller_config
+from robosuite.utils.input_utils import *
+
+if __name__ == "__main__":
+
+ # Create dict to hold options that will be passed to env creation call
+ options = {}
+
+ # print welcome info
+ print("Welcome to robosuite v{}!".format(suite.__version__))
+ print(suite.__logo__)
+
+ # Choose environment and add it to options
+ options["env_name"] = choose_environment()
+
+ # If a multi-arm environment has been chosen, choose configuration and appropriate robot(s)
+ if "TwoArm" in options["env_name"]:
+ # Choose env config and add it to options
+ options["env_configuration"] = choose_multi_arm_config()
+
+ # If chosen configuration was bimanual, the corresponding robot must be Baxter. Else, have user choose robots
+ if options["env_configuration"] == "bimanual":
+ options["robots"] = "Baxter"
+ else:
+ options["robots"] = []
+
+ # Have user choose two robots
+ print("A multiple single-arm configuration was chosen.\n")
+
+ for i in range(2):
+ print("Please choose Robot {}...\n".format(i))
+ options["robots"].append(choose_robots(exclude_bimanual=True))
+
+ # Else, we simply choose a single (single-armed) robot to instantiate in the environment
+ else:
+ options["robots"] = choose_robots(exclude_bimanual=True)
+
+ # Choose controller
+ controller_name = choose_controller()
+
+ # Load the desired controller
+ options["controller_configs"] = load_controller_config(default_controller=controller_name)
+
+ # initialize the task
+ env = suite.make(
+ **options,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ control_freq=20,
+ )
+ env.reset()
+ env.viewer.set_camera(camera_id=0)
+
+ # Get action limits
+ low, high = env.action_spec
+
+ # do visualization
+ for i in range(10000):
+ action = np.random.uniform(low, high)
+ obs, reward, done, _ = env.step(action)
+ env.render()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_renderers.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_renderers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6ae44ca6f0fb08953a5921ae173054c0b1930b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_renderers.py
@@ -0,0 +1,107 @@
+import argparse
+import json
+
+import numpy as np
+
+import robosuite as suite
+import robosuite.utils.transform_utils as T
+from robosuite.controllers import load_controller_config
+from robosuite.renderers import load_renderer_config
+from robosuite.utils.input_utils import *
+
+
+def str2bool(v):
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+if __name__ == "__main__":
+
+ """
+ Registered environments: Lift, Stack, NutAssembly, NutAssemblySingle, NutAssemblySquare, NutAssemblyRound,
+ PickPlace, PickPlaceSingle, PickPlaceMilk, PickPlaceBread, PickPlaceCereal,
+ PickPlaceCan, Door, Wipe, TwoArmLift, TwoArmPegInHole, TwoArmHandover
+
+ Possible robots: Baxter, IIWA, Jaco, Kinova3, Panda, Sawyer, UR5e
+ """
+
+ options = {}
+
+ # print welcome info
+ print("Welcome to robosuite v{}!".format(suite.__version__))
+ print(suite.__logo__)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--renderer", type=str, default="mujoco", help="Valid options include mujoco, and nvisii")
+
+ args = parser.parse_args()
+ renderer = args.renderer
+
+ options["env_name"] = choose_environment()
+
+ # If a multi-arm environment has been chosen, choose configuration and appropriate robot(s)
+ if "TwoArm" in options["env_name"]:
+ # Choose env config and add it to options
+ options["env_configuration"] = choose_multi_arm_config()
+
+ # If chosen configuration was bimanual, the corresponding robot must be Baxter. Else, have user choose robots
+ if options["env_configuration"] == "bimanual":
+ options["robots"] = "Baxter"
+ else:
+ options["robots"] = []
+
+ # Have user choose two robots
+ print("A multiple single-arm configuration was chosen.\n")
+
+ for i in range(2):
+ print("Please choose Robot {}...\n".format(i))
+ options["robots"].append(choose_robots(exclude_bimanual=True))
+
+ # Else, we simply choose a single (single-armed) robot to instantiate in the environment
+ else:
+ options["robots"] = choose_robots(exclude_bimanual=True)
+
+ # Choose controller
+ controller_name = choose_controller()
+
+ # Load the desired controller
+ options["controller_configs"] = load_controller_config(default_controller=controller_name)
+
+ env = suite.make(
+ **options,
+ has_renderer=False if renderer != "mujoco" else True, # no on-screen renderer
+ has_offscreen_renderer=False, # no off-screen renderer
+ ignore_done=True,
+ use_camera_obs=False, # no camera observations
+ control_freq=20,
+ renderer=renderer,
+ )
+
+ env.reset()
+
+ low, high = env.action_spec
+
+ if renderer == "nvisii":
+
+ timesteps = 300
+ for i in range(timesteps):
+ action = np.random.uniform(low, high)
+ obs, reward, done, _ = env.step(action)
+
+ if i % 100 == 0:
+ env.render()
+
+ else:
+
+ # do visualization
+ for i in range(10000):
+ action = np.random.uniform(low, high)
+ obs, reward, done, _ = env.step(action)
+ env.render()
+
+ env.close_renderer()
+ print("Done.")
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_segmentation.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c25b962ead4945e55f0fe87a83d8e55d9fbcdb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_segmentation.py
@@ -0,0 +1,118 @@
+"""
+Play random actions in an environment and render a video that demonstrates segmentation.
+"""
+import argparse
+import colorsys
+import json
+import random
+
+import imageio
+import matplotlib.cm as cm
+import numpy as np
+from PIL import Image
+
+import robosuite as suite
+from robosuite.controllers import load_controller_config
+
+
+def randomize_colors(N, bright=True):
+ """
+ Modified from https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/visualize.py#L59
+ Generate random colors.
+ To get visually distinct colors, generate them in HSV space then
+ convert to RGB.
+ """
+ brightness = 1.0 if bright else 0.5
+ hsv = [(1.0 * i / N, 1, brightness) for i in range(N)]
+ colors = np.array(list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)))
+ rstate = np.random.RandomState(seed=20)
+ np.random.shuffle(colors)
+ return colors
+
+
+def segmentation_to_rgb(seg_im, random_colors=False):
+ """
+ Helper function to visualize segmentations as RGB frames.
+ NOTE: assumes that geom IDs go up to 255 at most - if not,
+ multiple geoms might be assigned to the same color.
+ """
+ # ensure all values lie within [0, 255]
+ seg_im = np.mod(seg_im, 256)
+
+ if random_colors:
+ colors = randomize_colors(N=256, bright=True)
+ return (255.0 * colors[seg_im]).astype(np.uint8)
+ else:
+ # deterministic shuffling of values to map each geom ID to a random int in [0, 255]
+ rstate = np.random.RandomState(seed=8)
+ inds = np.arange(256)
+ rstate.shuffle(inds)
+
+ # use @inds to map each geom ID to a color
+ return (255.0 * cm.rainbow(inds[seg_im], 3)).astype(np.uint8)[..., :3]
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--video-path", type=str, default="/tmp/video.mp4", help="Path to video file")
+ parser.add_argument("--random-colors", action="store_true", help="Radnomize segmentation colors")
+ parser.add_argument("--segmentation-level", type=str, default="element", help="instance, class, or element")
+ args = parser.parse_args()
+
+ # Create dict to hold options that will be passed to env creation call
+ options = {}
+
+ # Choose environment and add it to options
+ options["env_name"] = "TwoArmHandover"
+ options["robots"] = ["Panda", "Panda"]
+
+ # Choose controller
+ controller_name = "OSC_POSE"
+
+ # Choose camera
+ camera = "frontview"
+
+ # Choose segmentation type
+ segmentation_level = args.segmentation_level # Options are {instance, class, element}
+
+ # Load the desired controller
+ options["controller_configs"] = load_controller_config(default_controller=controller_name)
+
+ # initialize the task
+ env = suite.make(
+ **options,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ ignore_done=True,
+ use_camera_obs=True,
+ control_freq=20,
+ camera_names=camera,
+ camera_segmentations=segmentation_level,
+ camera_heights=512,
+ camera_widths=512,
+ )
+ env.reset()
+
+ video_writer = imageio.get_writer(args.video_path, fps=20)
+
+ # Get action limits
+ low, high = env.action_spec
+
+ # do visualization
+ for i in range(100):
+ action = 0.5 * np.random.uniform(low, high)
+ obs, reward, done, _ = env.step(action)
+
+ video_img = obs[f"{camera}_segmentation_{segmentation_level}"].squeeze(-1)[::-1]
+ np.savetxt("/tmp/seg_{}.txt".format(i), video_img, fmt="%.2f")
+ video_img = segmentation_to_rgb(video_img, args.random_colors)
+ video_writer.append_data(video_img)
+
+ image = Image.fromarray(video_img)
+ image.save("/tmp/seg_{}.png".format(i))
+ if i % 5 == 0:
+ print("Step #{} / 100".format(i))
+
+ video_writer.close()
+ print("Video saved to {}".format(args.video_path))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_sensor_corruption.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_sensor_corruption.py
new file mode 100644
index 0000000000000000000000000000000000000000..8207741e14cc4d99efff3fd379e302f935f91fbe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_sensor_corruption.py
@@ -0,0 +1,264 @@
+"""Sensor Corruption Demo.
+
+This script provides an example of using the Observables functionality to implement a corrupted sensor
+(corruption + delay).
+Images will be rendered in a delayed fashion, such that the user will have seemingly delayed actions
+
+This is a modified version of the demo_device_control teleoperation script.
+
+Example:
+ $ python demo_sensor_corruption.py --environment Stack --robots Panda --delay 0.05 --corruption 5.0 --toggle-corruption-on-grasp
+"""
+
+import argparse
+import sys
+
+import cv2
+import numpy as np
+
+import robosuite as suite
+from robosuite import load_controller_config
+from robosuite.utils.input_utils import input2action
+from robosuite.utils.observables import Observable, create_gaussian_noise_corrupter, create_uniform_sampled_delayer
+from robosuite.wrappers import VisualizationWrapper
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--environment", type=str, default="Lift")
+ parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
+ parser.add_argument(
+ "--config", type=str, default="single-arm-opposed", help="Specified environment configuration if necessary"
+ )
+ parser.add_argument("--arm", type=str, default="right", help="Which arm to control (eg bimanual) 'right' or 'left'")
+ parser.add_argument("--switch-on-grasp", action="store_true", help="Switch gripper control on gripper action")
+ parser.add_argument(
+ "--toggle-corruption-on-grasp", action="store_true", help="Toggle corruption ON / OFF on gripper action"
+ )
+ parser.add_argument("--controller", type=str, default="osc", help="Choice of controller. Can be 'ik' or 'osc'")
+ parser.add_argument("--device", type=str, default="keyboard")
+ parser.add_argument("--pos-sensitivity", type=float, default=1.0, help="How much to scale position user inputs")
+ parser.add_argument("--rot-sensitivity", type=float, default=1.0, help="How much to scale rotation user inputs")
+ parser.add_argument("--delay", type=float, default=0.04, help="average delay to use (sec)")
+ parser.add_argument("--corruption", type=float, default=20.0, help="Scale of corruption to use (std dev)")
+ parser.add_argument("--camera", type=str, default="agentview", help="Name of camera to render")
+ parser.add_argument("--width", type=int, default=512)
+ parser.add_argument("--height", type=int, default=384)
+ args = parser.parse_args()
+
+ # Import controller config for EE IK or OSC (pos/ori)
+ if args.controller == "ik":
+ controller_name = "IK_POSE"
+ elif args.controller == "osc":
+ controller_name = "OSC_POSE"
+ else:
+ print("Error: Unsupported controller specified. Must be either 'ik' or 'osc'!")
+ raise ValueError
+
+ # Get controller config
+ controller_config = load_controller_config(default_controller=controller_name)
+
+ # Create argument configuration
+ config = {
+ "env_name": args.environment,
+ "robots": args.robots,
+ "controller_configs": controller_config,
+ }
+
+ # Check if we're using a multi-armed environment and use env_configuration argument if so
+ if "TwoArm" in args.environment:
+ config["env_configuration"] = args.config
+ else:
+ args.config = None
+
+ # Create environment
+ env = suite.make(
+ **config,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ ignore_done=True,
+ camera_names=args.camera,
+ camera_heights=args.height,
+ camera_widths=args.width,
+ use_camera_obs=True,
+ use_object_obs=True,
+ hard_reset=False,
+ )
+
+ # Wrap this environment in a visualization wrapper
+ env = VisualizationWrapper(env, indicator_configs=None)
+
+ # Set shared settings
+ attributes = ["corrupter", "delayer", "sampling_rate"]
+ corruption_mode = 1 # 1 is corruption = ON, 0 is corruption = OFF
+ obs_settings = {}
+
+ # Function to easily modify observable on the fly
+ def modify_obs(obs_name, attrs, mods):
+ for attr, mod in zip(attrs, mods):
+ env.modify_observable(
+ observable_name=obs_name,
+ attribute=attr,
+ modifier=mod,
+ )
+
+ # Add image corruption and delay
+ image_sampling_rate = 10.0
+ image_obs_name = f"{args.camera}_image"
+ image_corrupter = create_gaussian_noise_corrupter(mean=0.0, std=args.corruption, low=0, high=255)
+ image_delayer = create_uniform_sampled_delayer(min_delay=max(0, args.delay - 0.025), max_delay=args.delay + 0.025)
+ image_modifiers = [image_corrupter, image_delayer, image_sampling_rate]
+
+ # Initialize settings
+ modify_obs(obs_name=image_obs_name, attrs=attributes, mods=image_modifiers)
+
+ # Add entry for the corruption / delay settings in dict
+ obs_settings[image_obs_name] = {
+ "attrs": attributes[:2],
+ "mods": lambda: image_modifiers[:2] if corruption_mode else [None, None],
+ }
+
+ # Add proprioception corruption and delay
+ proprio_sampling_rate = 20.0
+ proprio_obs_name = f"{env.robots[0].robot_model.naming_prefix}joint_pos"
+ joint_limits = env.sim.model.jnt_range[env.robots[0]._ref_joint_indexes]
+ joint_range = joint_limits[:, 1] - joint_limits[:, 0]
+ proprio_corrupter = create_gaussian_noise_corrupter(mean=0.0, std=joint_range / 50.0)
+ curr_proprio_delay = 0.0
+ tmp_delayer = create_uniform_sampled_delayer(
+ min_delay=max(0, (args.delay - 0.025) / 2), max_delay=(args.delay + 0.025) / 2
+ )
+
+ # Define delayer to synchronize delay between ground truth and corrupted sensors
+ def proprio_delayer():
+ global curr_proprio_delay
+ curr_proprio_delay = tmp_delayer()
+ return curr_proprio_delay
+
+ # Define function to convert raw delay time to actual sampling delay (in discrete timesteps)
+ def calculate_proprio_delay():
+ base = env.model_timestep
+ return base * round(curr_proprio_delay / base) if corruption_mode else 0.0
+
+ proprio_modifiers = [proprio_corrupter, proprio_delayer, proprio_sampling_rate]
+
+ # We will create a separate "ground truth" delayed proprio observable to track exactly
+ # how much corruption we're getting in real time
+ proprio_sensor = env._observables[proprio_obs_name]._sensor
+ proprio_ground_truth_obs_name = f"{proprio_obs_name}_ground_truth"
+ observable = Observable(
+ name=proprio_ground_truth_obs_name,
+ sensor=proprio_sensor,
+ delayer=lambda: curr_proprio_delay,
+ sampling_rate=proprio_sampling_rate,
+ )
+
+ # Add this observable
+ env.add_observable(observable)
+
+ # We also need to set the normal joint pos observable to be active (not active by default)
+ env.modify_observable(observable_name=proprio_obs_name, attribute="active", modifier=True)
+
+ # Initialize settings
+ modify_obs(obs_name=proprio_obs_name, attrs=attributes, mods=proprio_modifiers)
+
+ # Add entry for the corruption / delay settings in dict
+ obs_settings[proprio_obs_name] = {
+ "attrs": attributes[:2],
+ "mods": lambda: proprio_modifiers[:2] if corruption_mode else [None, None],
+ }
+ obs_settings[proprio_ground_truth_obs_name] = {
+ "attrs": [attributes[1]],
+ "mods": lambda: [lambda: curr_proprio_delay] if corruption_mode else [None],
+ }
+
+ # Setup printing options for numbers
+ np.set_printoptions(precision=3, suppress=True, floatmode="fixed")
+
+ # initialize device
+ if args.device == "keyboard":
+ from robosuite.devices import Keyboard
+
+ device = Keyboard(pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
+ elif args.device == "spacemouse":
+ from robosuite.devices import SpaceMouse
+
+ device = SpaceMouse(pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
+ else:
+ raise Exception("Invalid device choice: choose either 'keyboard' or 'spacemouse'.")
+
+ while True:
+ # Reset the environment
+ obs = env.reset()
+
+ # Reset corruption mode
+ corruption_mode = 1
+
+ # Initialize variables that should the maintained between resets
+ last_grasp = 0
+
+ # Initialize device control
+ device.start_control()
+
+ while True:
+ # Set active robot
+ active_robot = env.robots[0] if args.config == "bimanual" else env.robots[args.arm == "left"]
+
+ # Get the newest action
+ action, grasp = input2action(
+ device=device, robot=active_robot, active_arm=args.arm, env_configuration=args.config
+ )
+
+ # If action is none, then this a reset so we should break
+ if action is None:
+ break
+
+ # If the current grasp is active (1) and last grasp is not (-1) (i.e.: grasping input just pressed),
+ # toggle arm control and / or corruption if requested
+ if last_grasp < 0 < grasp:
+ if args.switch_on_grasp:
+ args.arm = "left" if args.arm == "right" else "right"
+ if args.toggle_corruption_on_grasp:
+ # Toggle corruption and update observable
+ corruption_mode = 1 - corruption_mode
+ for obs_name, settings in obs_settings.items():
+ modify_obs(obs_name=obs_name, attrs=settings["attrs"], mods=settings["mods"]())
+ # Update last grasp
+ last_grasp = grasp
+
+ # Fill out the rest of the action space if necessary
+ rem_action_dim = env.action_dim - action.size
+ if rem_action_dim > 0:
+ # Initialize remaining action space
+ rem_action = np.zeros(rem_action_dim)
+ # This is a multi-arm setting, choose which arm to control and fill the rest with zeros
+ if args.arm == "right":
+ action = np.concatenate([action, rem_action])
+ elif args.arm == "left":
+ action = np.concatenate([rem_action, action])
+ else:
+ # Only right and left arms supported
+ print(
+ "Error: Unsupported arm specified -- "
+ "must be either 'right' or 'left'! Got: {}".format(args.arm)
+ )
+ elif rem_action_dim < 0:
+ # We're in an environment with no gripper action space, so trim the action space to be the action dim
+ action = action[: env.action_dim]
+
+ # Step through the simulation and render
+ obs, reward, done, info = env.step(action)
+
+ # Calculate and print out stats for proprio observation
+ observed_value = obs[proprio_obs_name]
+ ground_truth_delayed_value = obs[proprio_ground_truth_obs_name]
+ print(
+ f"Observed joint pos: {observed_value}, "
+ f"Corruption: {observed_value - ground_truth_delayed_value}, "
+ f"Delay: {calculate_proprio_delay():.3f} sec"
+ )
+
+ # read camera observation
+ im = np.flip(obs[args.camera + "_image"][..., ::-1], 0).astype(np.uint8)
+
+ cv2.imshow("offscreen render", im)
+ cv2.waitKey(1)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/demos/demo_video_recording.py b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_video_recording.py
new file mode 100644
index 0000000000000000000000000000000000000000..3424f289a9a95ec8029a0a9bee2cdc7f025f09cb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/demos/demo_video_recording.py
@@ -0,0 +1,69 @@
+"""
+Record video of agent episodes with the imageio library.
+This script uses offscreen rendering.
+
+Example:
+ $ python demo_video_recording.py --environment Lift --robots Panda
+"""
+
+import argparse
+
+import imageio
+import numpy as np
+
+import robosuite.macros as macros
+from robosuite import make
+
+# Set the image convention to opencv so that the images are automatically rendered "right side up" when using imageio
+# (which uses opencv convention)
+macros.IMAGE_CONVENTION = "opencv"
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--environment", type=str, default="Stack")
+ parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
+ parser.add_argument("--camera", type=str, default="agentview", help="Name of camera to render")
+ parser.add_argument("--video_path", type=str, default="video.mp4")
+ parser.add_argument("--timesteps", type=int, default=500)
+ parser.add_argument("--height", type=int, default=512)
+ parser.add_argument("--width", type=int, default=512)
+ parser.add_argument("--skip_frame", type=int, default=1)
+ args = parser.parse_args()
+
+ # initialize an environment with offscreen renderer
+ env = make(
+ args.environment,
+ args.robots,
+ has_renderer=False,
+ ignore_done=True,
+ use_camera_obs=True,
+ use_object_obs=False,
+ camera_names=args.camera,
+ camera_heights=args.height,
+ camera_widths=args.width,
+ )
+
+ obs = env.reset()
+ ndim = env.action_dim
+
+ # create a video writer with imageio
+ writer = imageio.get_writer(args.video_path, fps=20)
+
+ frames = []
+ for i in range(args.timesteps):
+
+ # run a uniformly random agent
+ action = 0.5 * np.random.randn(ndim)
+ obs, reward, done, info = env.step(action)
+
+ # dump a frame from every K frames
+ if i % args.skip_frame == 0:
+ frame = obs[args.camera + "_image"]
+ writer.append_data(frame)
+ print("Saving frame #{}".format(i))
+
+ if done:
+ break
+
+ writer.close()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/devices/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/devices/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddabf0335463d61f23889c261ec8b5f11c707c24
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/devices/__init__.py
@@ -0,0 +1,11 @@
+from .device import Device
+from .keyboard import Keyboard
+
+try:
+ from .spacemouse import SpaceMouse
+except ImportError:
+ print(
+ """Unable to load module hid, required to interface with SpaceMouse.\n
+ Only macOS is officially supported. Install the additional\n
+ requirements with `pip install -r requirements-extra.txt`"""
+ )
diff --git a/phantom/submodules/phantom-robosuite/robosuite/devices/device.py b/phantom/submodules/phantom-robosuite/robosuite/devices/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..918523b751d5a1aad0e0310b3a475837d20c018f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/devices/device.py
@@ -0,0 +1,21 @@
+import abc # for abstract base class definitions
+
+
+class Device(metaclass=abc.ABCMeta):
+ """
+ Base class for all robot controllers.
+ Defines basic interface for all controllers to adhere to.
+ """
+
+ @abc.abstractmethod
+ def start_control(self):
+ """
+ Method that should be called externally before controller can
+ start receiving commands.
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def get_controller_state(self):
+ """Returns the current state of the device, a dictionary of pos, orn, grasp, and reset."""
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/devices/keyboard.py b/phantom/submodules/phantom-robosuite/robosuite/devices/keyboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb37648d61f6efe5d9d1946cc155effccad65d35
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/devices/keyboard.py
@@ -0,0 +1,170 @@
+"""
+Driver class for Keyboard controller.
+"""
+
+import numpy as np
+from pynput.keyboard import Controller, Key, Listener
+
+from robosuite.devices import Device
+from robosuite.utils.transform_utils import rotation_matrix
+
+
+class Keyboard(Device):
+ """
+ A minimalistic driver class for a Keyboard.
+ Args:
+ pos_sensitivity (float): Magnitude of input position command scaling
+ rot_sensitivity (float): Magnitude of scale input rotation commands scaling
+ """
+
+ def __init__(self, pos_sensitivity=1.0, rot_sensitivity=1.0):
+
+ self._display_controls()
+ self._reset_internal_state()
+
+ self._reset_state = 0
+ self._enabled = False
+ self._pos_step = 0.05
+
+ self.pos_sensitivity = pos_sensitivity
+ self.rot_sensitivity = rot_sensitivity
+
+ # make a thread to listen to keyboard and register our callback functions
+ self.listener = Listener(on_press=self.on_press, on_release=self.on_release)
+
+ # start listening
+ self.listener.start()
+
+ @staticmethod
+ def _display_controls():
+ """
+ Method to pretty print controls.
+ """
+
+ def print_command(char, info):
+ char += " " * (10 - len(char))
+ print("{}\t{}".format(char, info))
+
+ print("")
+ print_command("Keys", "Command")
+ print_command("q", "reset simulation")
+ print_command("spacebar", "toggle gripper (open/close)")
+ print_command("w-a-s-d", "move arm horizontally in x-y plane")
+ print_command("r-f", "move arm vertically")
+ print_command("z-x", "rotate arm about x-axis")
+ print_command("t-g", "rotate arm about y-axis")
+ print_command("c-v", "rotate arm about z-axis")
+ print("")
+
+ def _reset_internal_state(self):
+ """
+ Resets internal state of controller, except for the reset signal.
+ """
+ self.rotation = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])
+ self.raw_drotation = np.zeros(3) # immediate roll, pitch, yaw delta values from keyboard hits
+ self.last_drotation = np.zeros(3)
+ self.pos = np.zeros(3) # (x, y, z)
+ self.last_pos = np.zeros(3)
+ self.grasp = False
+
+ def start_control(self):
+ """
+ Method that should be called externally before controller can
+ start receiving commands.
+ """
+ self._reset_internal_state()
+ self._reset_state = 0
+ self._enabled = True
+
+ def get_controller_state(self):
+ """
+ Grabs the current state of the keyboard.
+ Returns:
+ dict: A dictionary containing dpos, orn, unmodified orn, grasp, and reset
+ """
+
+ dpos = self.pos - self.last_pos
+ self.last_pos = np.array(self.pos)
+ raw_drotation = (
+ self.raw_drotation - self.last_drotation
+ ) # create local variable to return, then reset internal drotation
+ self.last_drotation = np.array(self.raw_drotation)
+ return dict(
+ dpos=dpos,
+ rotation=self.rotation,
+ raw_drotation=raw_drotation,
+ grasp=int(self.grasp),
+ reset=self._reset_state,
+ )
+
+ def on_press(self, key):
+ """
+ Key handler for key presses.
+ Args:
+ key (str): key that was pressed
+ """
+
+ try:
+ # controls for moving position
+ if key.char == "w":
+ self.pos[0] -= self._pos_step * self.pos_sensitivity # dec x
+ elif key.char == "s":
+ self.pos[0] += self._pos_step * self.pos_sensitivity # inc x
+ elif key.char == "a":
+ self.pos[1] -= self._pos_step * self.pos_sensitivity # dec y
+ elif key.char == "d":
+ self.pos[1] += self._pos_step * self.pos_sensitivity # inc y
+ elif key.char == "f":
+ self.pos[2] -= self._pos_step * self.pos_sensitivity # dec z
+ elif key.char == "r":
+ self.pos[2] += self._pos_step * self.pos_sensitivity # inc z
+
+ # controls for moving orientation
+ elif key.char == "z":
+ drot = rotation_matrix(angle=0.1 * self.rot_sensitivity, direction=[1.0, 0.0, 0.0])[:3, :3]
+ self.rotation = self.rotation.dot(drot) # rotates x
+ self.raw_drotation[1] -= 0.1 * self.rot_sensitivity
+ elif key.char == "x":
+ drot = rotation_matrix(angle=-0.1 * self.rot_sensitivity, direction=[1.0, 0.0, 0.0])[:3, :3]
+ self.rotation = self.rotation.dot(drot) # rotates x
+ self.raw_drotation[1] += 0.1 * self.rot_sensitivity
+ elif key.char == "t":
+ drot = rotation_matrix(angle=0.1 * self.rot_sensitivity, direction=[0.0, 1.0, 0.0])[:3, :3]
+ self.rotation = self.rotation.dot(drot) # rotates y
+ self.raw_drotation[0] += 0.1 * self.rot_sensitivity
+ elif key.char == "g":
+ drot = rotation_matrix(angle=-0.1 * self.rot_sensitivity, direction=[0.0, 1.0, 0.0])[:3, :3]
+ self.rotation = self.rotation.dot(drot) # rotates y
+ self.raw_drotation[0] -= 0.1 * self.rot_sensitivity
+ elif key.char == "c":
+ drot = rotation_matrix(angle=0.1 * self.rot_sensitivity, direction=[0.0, 0.0, 1.0])[:3, :3]
+ self.rotation = self.rotation.dot(drot) # rotates z
+ self.raw_drotation[2] += 0.1 * self.rot_sensitivity
+ elif key.char == "v":
+ drot = rotation_matrix(angle=-0.1 * self.rot_sensitivity, direction=[0.0, 0.0, 1.0])[:3, :3]
+ self.rotation = self.rotation.dot(drot) # rotates z
+ self.raw_drotation[2] -= 0.1 * self.rot_sensitivity
+
+ except AttributeError as e:
+ pass
+
+ def on_release(self, key):
+ """
+ Key handler for key releases.
+ Args:
+ key (str): key that was pressed
+ """
+
+ try:
+ # controls for grasping
+ if key == Key.space:
+ self.grasp = not self.grasp # toggle gripper
+
+ # user-commanded reset
+ elif key.char == "q":
+ self._reset_state = 1
+ self._enabled = False
+ self._reset_internal_state()
+
+ except AttributeError as e:
+ pass
diff --git a/phantom/submodules/phantom-robosuite/robosuite/devices/spacemouse.py b/phantom/submodules/phantom-robosuite/robosuite/devices/spacemouse.py
new file mode 100644
index 0000000000000000000000000000000000000000..604989ff28f8dbc9662097b9f8f6e173b1c6c85d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/devices/spacemouse.py
@@ -0,0 +1,317 @@
+"""Driver class for SpaceMouse controller.
+
+This class provides a driver support to SpaceMouse on macOS.
+In particular, we assume you are using a SpaceMouse Wireless by default.
+
+To set up a new SpaceMouse controller:
+ 1. Download and install driver from https://www.3dconnexion.com/service/drivers.html
+ 2. Install hidapi library through pip
+ (make sure you run uninstall hid first if it is installed).
+ 3. Make sure SpaceMouse is connected before running the script
+ 4. (Optional) Based on the model of SpaceMouse, you might need to change the
+ vendor id and product id that correspond to the device.
+
+For Linux support, you can find open-source Linux drivers and SDKs online.
+ See http://spacenav.sourceforge.net/
+
+"""
+
+import threading
+import time
+from collections import namedtuple
+
+import numpy as np
+
+try:
+ import hid
+except ModuleNotFoundError as exc:
+ raise ImportError(
+ "Unable to load module hid, required to interface with SpaceMouse. "
+ "Only macOS is officially supported. Install the additional "
+ "requirements with `pip install -r requirements-extra.txt`"
+ ) from exc
+
+import robosuite.macros as macros
+from robosuite.devices import Device
+from robosuite.utils.transform_utils import rotation_matrix
+
+AxisSpec = namedtuple("AxisSpec", ["channel", "byte1", "byte2", "scale"])
+
+SPACE_MOUSE_SPEC = {
+ "x": AxisSpec(channel=1, byte1=1, byte2=2, scale=1),
+ "y": AxisSpec(channel=1, byte1=3, byte2=4, scale=-1),
+ "z": AxisSpec(channel=1, byte1=5, byte2=6, scale=-1),
+ "roll": AxisSpec(channel=1, byte1=7, byte2=8, scale=-1),
+ "pitch": AxisSpec(channel=1, byte1=9, byte2=10, scale=-1),
+ "yaw": AxisSpec(channel=1, byte1=11, byte2=12, scale=1),
+}
+
+
+def to_int16(y1, y2):
+ """
+ Convert two 8 bit bytes to a signed 16 bit integer.
+
+ Args:
+ y1 (int): 8-bit byte
+ y2 (int): 8-bit byte
+
+ Returns:
+ int: 16-bit integer
+ """
+ x = (y1) | (y2 << 8)
+ if x >= 32768:
+ x = -(65536 - x)
+ return x
+
+
+def scale_to_control(x, axis_scale=350.0, min_v=-1.0, max_v=1.0):
+ """
+ Normalize raw HID readings to target range.
+
+ Args:
+ x (int): Raw reading from HID
+ axis_scale (float): (Inverted) scaling factor for mapping raw input value
+ min_v (float): Minimum limit after scaling
+ max_v (float): Maximum limit after scaling
+
+ Returns:
+ float: Clipped, scaled input from HID
+ """
+ x = x / axis_scale
+ x = min(max(x, min_v), max_v)
+ return x
+
+
+def convert(b1, b2):
+ """
+ Converts SpaceMouse message to commands.
+
+ Args:
+ b1 (int): 8-bit byte
+ b2 (int): 8-bit byte
+
+ Returns:
+ float: Scaled value from Spacemouse message
+ """
+ return scale_to_control(to_int16(b1, b2))
+
+
+class SpaceMouse(Device):
+ """
+ A minimalistic driver class for SpaceMouse with HID library.
+
+ Note: Use hid.enumerate() to view all USB human interface devices (HID).
+ Make sure SpaceMouse is detected before running the script.
+ You can look up its vendor/product id from this method.
+
+ Args:
+ vendor_id (int): HID device vendor id
+ product_id (int): HID device product id
+ pos_sensitivity (float): Magnitude of input position command scaling
+ rot_sensitivity (float): Magnitude of scale input rotation commands scaling
+ """
+
+ def __init__(
+ self,
+ vendor_id=macros.SPACEMOUSE_VENDOR_ID,
+ product_id=macros.SPACEMOUSE_PRODUCT_ID,
+ pos_sensitivity=1.0,
+ rot_sensitivity=1.0,
+ ):
+
+ print("Opening SpaceMouse device")
+ self.vendor_id = vendor_id
+ self.product_id = product_id
+ self.device = hid.device()
+ self.device.open(self.vendor_id, self.product_id) # SpaceMouse
+
+ self.pos_sensitivity = pos_sensitivity
+ self.rot_sensitivity = rot_sensitivity
+
+ print("Manufacturer: %s" % self.device.get_manufacturer_string())
+ print("Product: %s" % self.device.get_product_string())
+
+ # 6-DOF variables
+ self.x, self.y, self.z = 0, 0, 0
+ self.roll, self.pitch, self.yaw = 0, 0, 0
+
+ self._display_controls()
+
+ self.single_click_and_hold = False
+
+ self._control = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ self._reset_state = 0
+ self.rotation = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])
+ self._enabled = False
+
+ # launch a new listener thread to listen to SpaceMouse
+ self.thread = threading.Thread(target=self.run)
+ self.thread.daemon = True
+ self.thread.start()
+
+ @staticmethod
+ def _display_controls():
+ """
+ Method to pretty print controls.
+ """
+
+ def print_command(char, info):
+ char += " " * (30 - len(char))
+ print("{}\t{}".format(char, info))
+
+ print("")
+ print_command("Control", "Command")
+ print_command("Right button", "reset simulation")
+ print_command("Left button (hold)", "close gripper")
+ print_command("Move mouse laterally", "move arm horizontally in x-y plane")
+ print_command("Move mouse vertically", "move arm vertically")
+ print_command("Twist mouse about an axis", "rotate arm about a corresponding axis")
+ print("")
+
+ def _reset_internal_state(self):
+ """
+ Resets internal state of controller, except for the reset signal.
+ """
+ self.rotation = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])
+ # Reset 6-DOF variables
+ self.x, self.y, self.z = 0, 0, 0
+ self.roll, self.pitch, self.yaw = 0, 0, 0
+ # Reset control
+ self._control = np.zeros(6)
+ # Reset grasp
+ self.single_click_and_hold = False
+
+ def start_control(self):
+ """
+ Method that should be called externally before controller can
+ start receiving commands.
+ """
+ self._reset_internal_state()
+ self._reset_state = 0
+ self._enabled = True
+
+ def get_controller_state(self):
+ """
+ Grabs the current state of the 3D mouse.
+
+ Returns:
+ dict: A dictionary containing dpos, orn, unmodified orn, grasp, and reset
+ """
+ dpos = self.control[:3] * 0.005 * self.pos_sensitivity
+ roll, pitch, yaw = self.control[3:] * 0.005 * self.rot_sensitivity
+
+ # convert RPY to an absolute orientation
+ drot1 = rotation_matrix(angle=-pitch, direction=[1.0, 0, 0], point=None)[:3, :3]
+ drot2 = rotation_matrix(angle=roll, direction=[0, 1.0, 0], point=None)[:3, :3]
+ drot3 = rotation_matrix(angle=yaw, direction=[0, 0, 1.0], point=None)[:3, :3]
+
+ self.rotation = self.rotation.dot(drot1.dot(drot2.dot(drot3)))
+
+ return dict(
+ dpos=dpos,
+ rotation=self.rotation,
+ raw_drotation=np.array([roll, pitch, yaw]),
+ grasp=self.control_gripper,
+ reset=self._reset_state,
+ )
+
+ def run(self):
+ """Listener method that keeps pulling new messages."""
+
+ t_last_click = -1
+
+ while True:
+ d = self.device.read(13)
+ if d is not None and self._enabled:
+
+ if self.product_id == 50741:
+ ## logic for older spacemouse model
+
+ if d[0] == 1: ## readings from 6-DoF sensor
+ self.y = convert(d[1], d[2])
+ self.x = convert(d[3], d[4])
+ self.z = convert(d[5], d[6]) * -1.0
+
+ elif d[0] == 2:
+
+ self.roll = convert(d[1], d[2])
+ self.pitch = convert(d[3], d[4])
+ self.yaw = convert(d[5], d[6])
+
+ self._control = [
+ self.x,
+ self.y,
+ self.z,
+ self.roll,
+ self.pitch,
+ self.yaw,
+ ]
+ else:
+ ## default logic for all other spacemouse models
+
+ if d[0] == 1: ## readings from 6-DoF sensor
+ self.y = convert(d[1], d[2])
+ self.x = convert(d[3], d[4])
+ self.z = convert(d[5], d[6]) * -1.0
+
+ self.roll = convert(d[7], d[8])
+ self.pitch = convert(d[9], d[10])
+ self.yaw = convert(d[11], d[12])
+
+ self._control = [
+ self.x,
+ self.y,
+ self.z,
+ self.roll,
+ self.pitch,
+ self.yaw,
+ ]
+
+ if d[0] == 3: ## readings from the side buttons
+
+ # press left button
+ if d[1] == 1:
+ t_click = time.time()
+ elapsed_time = t_click - t_last_click
+ t_last_click = t_click
+ self.single_click_and_hold = True
+
+ # release left button
+ if d[1] == 0:
+ self.single_click_and_hold = False
+
+ # right button is for reset
+ if d[1] == 2:
+ self._reset_state = 1
+ self._enabled = False
+ self._reset_internal_state()
+
+ @property
+ def control(self):
+ """
+ Grabs current pose of Spacemouse
+
+ Returns:
+ np.array: 6-DoF control value
+ """
+ return np.array(self._control)
+
+ @property
+ def control_gripper(self):
+ """
+ Maps internal states into gripper commands.
+
+ Returns:
+ float: Whether we're using single click and hold or not
+ """
+ if self.single_click_and_hold:
+ return 1.0
+ return 0
+
+
+if __name__ == "__main__":
+
+ space_mouse = SpaceMouse()
+ for i in range(100):
+ print(space_mouse.control, space_mouse.control_gripper)
+ time.sleep(0.02)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/environments/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fff6081f09064994187ac1154f62c57066bd455c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/__init__.py
@@ -0,0 +1,3 @@
+from .base import REGISTERED_ENVS, MujocoEnv
+
+ALL_ENVIRONMENTS = REGISTERED_ENVS.keys()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/base.py b/phantom/submodules/phantom-robosuite/robosuite/environments/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..62752b2fa9a13b8d3c9ab74e54b9aa2ee5a550a5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/base.py
@@ -0,0 +1,737 @@
+import os
+import xml.etree.ElementTree as ET
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite
+import robosuite.macros as macros
+import robosuite.utils.sim_utils as SU
+from robosuite.renderers.base import load_renderer_config
+from robosuite.utils import OpenCVRenderer, SimulationError, XMLError
+from robosuite.utils.binding_utils import MjRenderContextOffscreen, MjSim
+
+REGISTERED_ENVS = {}
+
+
+def register_env(target_class):
+ REGISTERED_ENVS[target_class.__name__] = target_class
+
+
+def make(env_name, *args, **kwargs):
+ """
+ Instantiates a robosuite environment.
+ This method attempts to mirror the equivalent functionality of gym.make in a somewhat sloppy way.
+ Args:
+ env_name (str): Name of the robosuite environment to initialize
+ *args: Additional arguments to pass to the specific environment class initializer
+ **kwargs: Additional arguments to pass to the specific environment class initializer
+ Returns:
+ MujocoEnv: Desired robosuite environment
+ Raises:
+ Exception: [Invalid environment name]
+ """
+ if env_name not in REGISTERED_ENVS:
+ raise Exception(
+ "Environment {} not found. Make sure it is a registered environment among: {}".format(
+ env_name, ", ".join(REGISTERED_ENVS)
+ )
+ )
+ return REGISTERED_ENVS[env_name](*args, **kwargs)
+
+
+class EnvMeta(type):
+ """Metaclass for registering environments"""
+
+ def __new__(meta, name, bases, class_dict):
+ cls = super().__new__(meta, name, bases, class_dict)
+
+ # List all environments that should not be registered here.
+ _unregistered_envs = ["MujocoEnv", "RobotEnv", "ManipulationEnv", "SingleArmEnv", "TwoArmEnv"]
+
+ if cls.__name__ not in _unregistered_envs:
+ register_env(cls)
+ return cls
+
+
+class MujocoEnv(metaclass=EnvMeta):
+ """
+ Initializes a Mujoco Environment.
+ Args:
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+ has_offscreen_renderer (bool): True if using off-screen rendering.
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+ render_collision_mesh (bool): True if rendering collision meshes
+ in camera. False otherwise.
+ render_visual_mesh (bool): True if rendering visual meshes
+ in camera. False otherwise.
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+ control_freq (float): how many control signals to receive
+ in every simulated second. This sets the amount of simulation time
+ that passes between every action input.
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+ renderer (str): string for the renderer to use
+ renderer_config (dict): dictionary for the renderer configurations
+ Raises:
+ ValueError: [Invalid renderer selection]
+ """
+
+ def __init__(
+ self,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # If you're using an onscreen renderer, you must be also using an offscreen renderer!
+ if has_renderer and not has_offscreen_renderer:
+ has_offscreen_renderer = True
+
+ # Rendering-specific attributes
+ self.has_renderer = has_renderer
+ # offscreen renderer needed for on-screen rendering
+ self.has_offscreen_renderer = has_renderer or has_offscreen_renderer
+ self.render_camera = render_camera
+ self.render_collision_mesh = render_collision_mesh
+ self.render_visual_mesh = render_visual_mesh
+ self.render_gpu_device_id = render_gpu_device_id
+ self.viewer = None
+
+ # Simulation-specific attributes
+ self._observables = {} # Maps observable names to observable objects
+ self._obs_cache = {} # Maps observable names to pre-/partially-computed observable values
+ self.control_freq = control_freq
+ self.horizon = horizon
+ self.ignore_done = ignore_done
+ self.hard_reset = hard_reset
+ self._xml_processor = None # Function to process model xml in _initialize_sim() call
+ self.model = None
+ self.cur_time = None
+ self.model_timestep = None
+ self.control_timestep = None
+ self.deterministic_reset = False # Whether to add randomized resetting of objects / robot joints
+
+ self.renderer = renderer
+ self.renderer_config = renderer_config
+
+ # Load the model
+ self._load_model()
+
+ # Initialize the simulation
+ self._initialize_sim()
+
+ # initializes the rendering
+ self.initialize_renderer()
+
+ # Run all further internal (re-)initialization required
+ self._reset_internal()
+
+ # Load observables
+ if hasattr(self.viewer, "_setup_observables"):
+ self._observables = self.viewer._setup_observables()
+ else:
+ self._observables = self._setup_observables()
+
+ # check if viewer has get observations method and set a flag for future use.
+ self.viewer_get_obs = hasattr(self.viewer, "_get_observations")
+
+ def initialize_renderer(self):
+ self.renderer = self.renderer.lower()
+
+ if self.renderer_config is None and self.renderer != "mujoco":
+ self.renderer_config = load_renderer_config(self.renderer)
+
+ if self.renderer == "mujoco" or self.renderer == "default":
+ pass
+ elif self.renderer == "nvisii":
+ from robosuite.renderers.nvisii.nvisii_renderer import NVISIIRenderer
+
+ self.viewer = NVISIIRenderer(env=self, **self.renderer_config)
+ else:
+ raise ValueError(
+ f"{self.renderer} is not a valid renderer name. Valid options include default (native mujoco renderer), and nvisii"
+ )
+
+ def initialize_time(self, control_freq):
+ """
+ Initializes the time constants used for simulation.
+ Args:
+ control_freq (float): Hz rate to run control loop at within the simulation
+ """
+ self.cur_time = 0
+ self.model_timestep = macros.SIMULATION_TIMESTEP
+ if self.model_timestep <= 0:
+ raise ValueError("Invalid simulation timestep defined!")
+ self.control_freq = control_freq
+ if control_freq <= 0:
+ raise SimulationError("Control frequency {} is invalid".format(control_freq))
+ self.control_timestep = 1.0 / control_freq
+
+ def set_xml_processor(self, processor):
+ """
+ Sets the processor function that xml string will be passed to inside _initialize_sim() calls.
+ Args:
+ processor (None or function): If set, processing method should take in a xml string and
+ return no arguments.
+ """
+ self._xml_processor = processor
+
+ def _load_model(self):
+ """Loads an xml model, puts it in self.model"""
+ pass
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ # Setup mappings from model to IDs
+ self.model.generate_id_mappings(sim=self.sim)
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment.
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ return OrderedDict()
+
+ def _initialize_sim(self, xml_string=None):
+ """
+ Creates a MjSim object and stores it in self.sim. If @xml_string is specified, the MjSim object will be created
+ from the specified xml_string. Else, it will pull from self.model to instantiate the simulation
+ Args:
+ xml_string (str): If specified, creates MjSim object from this filepath
+ """
+ xml = xml_string if xml_string else self.model.get_xml()
+
+ # process the xml before initializing sim
+ if self._xml_processor is not None:
+ xml = self._xml_processor(xml)
+
+ # Create the simulation instance
+ self.sim = MjSim.from_xml_string(xml)
+
+ # run a single step to make sure changes have propagated through sim state
+ self.sim.forward()
+
+ # Setup sim time based on control frequency
+ self.initialize_time(self.control_freq)
+
+ def reset(self):
+ """
+ Resets simulation.
+ Returns:
+ OrderedDict: Environment observation space after reset occurs
+ """
+ # TODO(yukez): investigate black screen of death
+ # Use hard reset if requested
+
+ if self.hard_reset and not self.deterministic_reset:
+ if self.renderer == "mujoco" or self.renderer == "default":
+ self._destroy_viewer()
+ self._destroy_sim()
+ self._load_model()
+ self._initialize_sim()
+ # Else, we only reset the sim internally
+ else:
+ self.sim.reset()
+
+ # Reset necessary robosuite-centric variables
+ self._reset_internal()
+ self.sim.forward()
+ # Setup observables, reloading if
+ self._obs_cache = {}
+ if self.hard_reset:
+ # If we're using hard reset, must re-update sensor object references
+ if hasattr(self.viewer, "_setup_observables"):
+ _observables = self.viewer._setup_observables()
+ else:
+ _observables = self._setup_observables()
+ for obs_name, obs in _observables.items():
+ self.modify_observable(observable_name=obs_name, attribute="sensor", modifier=obs._sensor)
+ # Make sure that all sites are toggled OFF by default
+ self.visualize(vis_settings={vis: False for vis in self._visualizations})
+
+ if self.viewer is not None and self.renderer != "mujoco":
+ self.viewer.reset()
+
+ observations = (
+ self.viewer._get_observations(force_update=True)
+ if self.viewer_get_obs
+ else self._get_observations(force_update=True)
+ )
+
+ # Return new observations
+ return observations
+
+ def _reset_internal(self):
+ """Resets simulation internal configurations."""
+
+ # create visualization screen or renderer
+ if self.has_renderer and self.viewer is None:
+ self.viewer = OpenCVRenderer(self.sim)
+
+ # Set the camera angle for viewing
+ if self.render_camera is not None:
+ camera_id = self.sim.model.camera_name2id(self.render_camera)
+ self.viewer.set_camera(camera_id)
+
+ if self.has_offscreen_renderer:
+ if self.sim._render_context_offscreen is None:
+ render_context = MjRenderContextOffscreen(self.sim, device_id=self.render_gpu_device_id)
+ self.sim._render_context_offscreen.vopt.geomgroup[0] = 1 if self.render_collision_mesh else 0
+ self.sim._render_context_offscreen.vopt.geomgroup[1] = 1 if self.render_visual_mesh else 0
+
+ # additional housekeeping
+ self.sim_state_initial = self.sim.get_state()
+ self._setup_references()
+ self.cur_time = 0
+ self.timestep = 0
+ self.done = False
+
+ # Empty observation cache and reset all observables
+ self._obs_cache = {}
+ for observable in self._observables.values():
+ observable.reset()
+
+ def _update_observables(self, force=False):
+ """
+ Updates all observables in this environment
+ Args:
+ force (bool): If True, will force all the observables to update their internal values to the newest
+ value. This is useful if, e.g., you want to grab observations when directly setting simulation states
+ without actually stepping the simulation.
+ """
+ for observable in self._observables.values():
+ observable.update(timestep=self.model_timestep, obs_cache=self._obs_cache, force=force)
+
+ def _get_observations(self, force_update=False):
+ """
+ Grabs observations from the environment.
+ Args:
+ force_update (bool): If True, will force all the observables to update their internal values to the newest
+ value. This is useful if, e.g., you want to grab observations when directly setting simulation states
+ without actually stepping the simulation.
+ Returns:
+ OrderedDict: OrderedDict containing observations [(name_string, np.array), ...]
+ """
+ observations = OrderedDict()
+ obs_by_modality = OrderedDict()
+
+ # Force an update if requested
+ if force_update:
+ self._update_observables(force=True)
+
+ # Loop through all observables and grab their current observation
+ for obs_name, observable in self._observables.items():
+ if observable.is_enabled() and observable.is_active():
+ obs = observable.obs
+ observations[obs_name] = obs
+ modality = observable.modality + "-state"
+ if modality not in obs_by_modality:
+ obs_by_modality[modality] = []
+ # Make sure all observations are numpy arrays so we can concatenate them
+ array_obs = [obs] if type(obs) in {int, float} or not obs.shape else obs
+ obs_by_modality[modality].append(np.array(array_obs))
+
+ # Add in modality observations
+ for modality, obs in obs_by_modality.items():
+ # To save memory, we only concatenate the image observations if explicitly requested
+ if modality == "image-state" and not macros.CONCATENATE_IMAGES:
+ continue
+ observations[modality] = np.concatenate(obs, axis=-1)
+
+ return observations
+
+ def step(self, action):
+ """
+ Takes a step in simulation with control command @action.
+ Args:
+ action (np.array): Action to execute within the environment
+ Returns:
+ 4-tuple:
+ - (OrderedDict) observations from the environment
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) misc information
+ Raises:
+ ValueError: [Steps past episode termination]
+ """
+ if self.done:
+ raise ValueError("executing action in terminated episode")
+
+ self.timestep += 1
+
+ # Since the env.step frequency is slower than the mjsim timestep frequency, the internal controller will output
+ # multiple torque commands in between new high level action commands. Therefore, we need to denote via
+ # 'policy_step' whether the current step we're taking is simply an internal update of the controller,
+ # or an actual policy update
+ policy_step = True
+
+ # Loop through the simulation at the model timestep rate until we're ready to take the next policy step
+ # (as defined by the control frequency specified at the environment level)
+ for i in range(int(self.control_timestep / self.model_timestep)):
+ self.sim.forward()
+ self._pre_action(action, policy_step)
+ self.sim.step()
+ self._update_observables()
+ policy_step = False
+
+ # Note: this is done all at once to avoid floating point inaccuracies
+ self.cur_time += self.control_timestep
+
+ reward, done, info = self._post_action(action)
+
+ if self.viewer is not None and self.renderer != "mujoco":
+ self.viewer.update()
+
+ observations = self.viewer._get_observations() if self.viewer_get_obs else self._get_observations()
+ return observations, reward, done, info
+
+ def _pre_action(self, action, policy_step=False):
+ """
+ Do any preprocessing before taking an action.
+ Args:
+ action (np.array): Action to execute within the environment
+ policy_step (bool): Whether this current loop is an actual policy step or internal sim update step
+ """
+ self.sim.data.ctrl[:] = action
+
+ def _post_action(self, action):
+ """
+ Do any housekeeping after taking an action.
+ Args:
+ action (np.array): Action to execute within the environment
+ Returns:
+ 3-tuple:
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) empty dict to be filled with information by subclassed method
+ """
+ reward = self.reward(action)
+
+ # done if number of elapsed timesteps is greater than horizon
+ self.done = (self.timestep >= self.horizon) and not self.ignore_done
+
+ return reward, self.done, {}
+
+ def reward(self, action):
+ """
+ Reward should be a function of state and action
+ Args:
+ action (np.array): Action to execute within the environment
+ Returns:
+ float: Reward from environment
+ """
+ raise NotImplementedError
+
+ def render(self):
+ """
+ Renders to an on-screen window.
+ """
+ self.viewer.render()
+
+ def get_pixel_obs(self):
+ """
+ Gets the pixel observations for the environment from the specified renderer
+ """
+ self.viewer.get_pixel_obs()
+
+ def close_renderer(self):
+ """
+ Closes the renderer
+ """
+ self.viewer.close()
+
+ def observation_spec(self):
+ """
+ Returns an observation as observation specification.
+ An alternative design is to return an OrderedDict where the keys
+ are the observation names and the values are the shapes of observations.
+ We leave this alternative implementation commented out, as we find the
+ current design is easier to use in practice.
+ Returns:
+ OrderedDict: Observations from the environment
+ """
+ observation = self.viewer._get_observations() if self.viewer_get_obs else self._get_observations()
+ return observation
+
+ def clear_objects(self, object_names):
+ """
+ Clears objects with the name @object_names out of the task space. This is useful
+ for supporting task modes with single types of objects, as in
+ @self.single_object_mode without changing the model definition.
+ Args:
+ object_names (str or list of str): Name of object(s) to remove from the task workspace
+ """
+ object_names = {object_names} if type(object_names) is str else set(object_names)
+ for obj in self.model.mujoco_objects:
+ if obj.name in object_names:
+ self.sim.data.set_joint_qpos(obj.joints[0], np.array((10, 10, 10, 1, 0, 0, 0)))
+
+ def visualize(self, vis_settings):
+ """
+ Do any needed visualization here
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "env" keyword as well as any other relevant
+ options specified.
+ """
+ # Set visuals for environment objects
+ for obj in self.model.mujoco_objects:
+ obj.set_sites_visibility(sim=self.sim, visible=vis_settings["env"])
+
+ def set_camera_pos_quat(self, camera_pos, camera_quat):
+ if self.renderer in ["nvisii"]:
+ self.viewer.set_camera_pos_quat(camera_pos, camera_quat)
+ else:
+ raise AttributeError("setting camera position and quat requires renderer to be NVISII.")
+
+ def edit_model_xml(self, xml_str):
+ """
+ This function edits the model xml with custom changes, including resolving relative paths,
+ applying changes retroactively to existing demonstration files, and other custom scripts.
+ Environment subclasses should modify this function to add environment-specific xml editing features.
+ Args:
+ xml_str (str): Mujoco sim demonstration XML file as string
+ Returns:
+ str: Edited xml file as string
+ """
+
+ path = os.path.split(robosuite.__file__)[0]
+ path_split = path.split("/")
+
+ # replace mesh and texture file paths
+ tree = ET.fromstring(xml_str)
+ root = tree
+ asset = root.find("asset")
+ meshes = asset.findall("mesh")
+ textures = asset.findall("texture")
+ all_elements = meshes + textures
+
+ for elem in all_elements:
+ old_path = elem.get("file")
+ if old_path is None:
+ continue
+ old_path_split = old_path.split("/")
+ ind = max(loc for loc, val in enumerate(old_path_split) if val == "robosuite") # last occurrence index
+ new_path_split = path_split + old_path_split[ind + 1 :]
+ new_path = "/".join(new_path_split)
+ elem.set("file", new_path)
+
+ return ET.tostring(root, encoding="utf8").decode("utf8")
+
+ def reset_from_xml_string(self, xml_string):
+ """
+ Reloads the environment from an XML description of the environment.
+ Args:
+ xml_string (str): Filepath to the xml file that will be loaded directly into the sim
+ """
+
+ # if there is an active viewer window, destroy it
+ if self.renderer != "nvisii":
+ self.close()
+
+ # Since we are reloading from an xml_string, we are deterministically resetting
+ self.deterministic_reset = True
+
+ # initialize sim from xml
+ self._initialize_sim(xml_string=xml_string)
+
+ # Now reset as normal
+ self.reset()
+
+ # Turn off deterministic reset
+ self.deterministic_reset = False
+
+ def check_contact(self, geoms_1, geoms_2=None):
+ """
+ Finds contact between two geom groups.
+ Args:
+ geoms_1 (str or list of str or MujocoModel): an individual geom name or list of geom names or a model. If
+ a MujocoModel is specified, the geoms checked will be its contact_geoms
+ geoms_2 (str or list of str or MujocoModel or None): another individual geom name or list of geom names.
+ If a MujocoModel is specified, the geoms checked will be its contact_geoms. If None, will check
+ any collision with @geoms_1 to any other geom in the environment
+ Returns:
+ bool: True if any geom in @geoms_1 is in contact with any geom in @geoms_2.
+ """
+ return SU.check_contact(sim=self.sim, geoms_1=geoms_1, geoms_2=geoms_2)
+
+ def get_contacts(self, model):
+ """
+ Checks for any contacts with @model (as defined by @model's contact_geoms) and returns the set of
+ geom names currently in contact with that model (excluding the geoms that are part of the model itself).
+ Args:
+ model (MujocoModel): Model to check contacts for.
+ Returns:
+ set: Unique geoms that are actively in contact with this model.
+ Raises:
+ AssertionError: [Invalid input type]
+ """
+ return SU.get_contacts(sim=self.sim, model=model)
+
+ def add_observable(self, observable):
+ """
+ Adds an observable to this environment.
+ Args:
+ observable (Observable): Observable instance.
+ """
+ assert observable.name not in self._observables, (
+ "Observable name {} is already associated with an existing observable! Use modify_observable(...) "
+ "to modify a pre-existing observable.".format(observable.name)
+ )
+ self._observables[observable.name] = observable
+
+ def modify_observable(self, observable_name, attribute, modifier):
+ """
+ Modifies observable with associated name @observable_name, replacing the given @attribute with @modifier.
+ Args:
+ observable_name (str): Observable to modify
+ attribute (str): Observable attribute to modify.
+ Options are {`'sensor'`, `'corrupter'`,`'filter'`, `'delayer'`, `'sampling_rate'`,
+ `'enabled'`, `'active'`}
+ modifier (any): New function / value to replace with for observable. If a function, new signature should
+ match the function being replaced.
+ """
+ # Find the observable
+ assert observable_name in self._observables, "No valid observable with name {} found. Options are: {}".format(
+ observable_name, self.observation_names
+ )
+ obs = self._observables[observable_name]
+ # replace attribute accordingly
+ if attribute == "sensor":
+ obs.set_sensor(modifier)
+ elif attribute == "corrupter":
+ obs.set_corrupter(modifier)
+ elif attribute == "filter":
+ obs.set_filter(modifier)
+ elif attribute == "delayer":
+ obs.set_delayer(modifier)
+ elif attribute == "sampling_rate":
+ obs.set_sampling_rate(modifier)
+ elif attribute == "enabled":
+ obs.set_enabled(modifier)
+ elif attribute == "active":
+ obs.set_active(modifier)
+ else:
+ # Invalid attribute specified
+ raise ValueError(
+ "Invalid observable attribute specified. Requested: {}, valid options are {}".format(
+ attribute, {"sensor", "corrupter", "filter", "delayer", "sampling_rate", "enabled", "active"}
+ )
+ )
+
+ def _check_success(self):
+ """
+ Checks if the task has been completed. Should be implemented by subclasses
+ Returns:
+ bool: True if the task has been completed
+ """
+ raise NotImplementedError
+
+ def _destroy_viewer(self):
+ """
+ Destroys the current mujoco renderer instance if it exists
+ """
+ # if there is an active viewer window, destroy it
+ if self.viewer is not None:
+ self.viewer.close() # change this to viewer.finish()?
+ self.viewer = None
+
+ def _destroy_sim(self):
+ """
+ Destroys the current MjSim instance if it exists
+ """
+ if self.sim is not None:
+ self.sim.free()
+ self.sim = None
+
+ def close(self):
+ """Do any cleanup necessary here."""
+ self._destroy_viewer()
+ self._destroy_sim()
+
+ @property
+ def observation_modalities(self):
+ """
+ Modalities for this environment's observations
+ Returns:
+ set: All observation modalities
+ """
+ return set([observable.modality for observable in self._observables.values()])
+
+ @property
+ def observation_names(self):
+ """
+ Grabs all names for this environment's observables
+ Returns:
+ set: All observation names
+ """
+ return set(self._observables.keys())
+
+ @property
+ def enabled_observables(self):
+ """
+ Grabs all names of enabled observables for this environment. An observable is considered enabled if its values
+ are being continually computed / updated at each simulation timestep.
+ Returns:
+ set: All enabled observation names
+ """
+ return set([name for name, observable in self._observables.items() if observable.is_enabled()])
+
+ @property
+ def active_observables(self):
+ """
+ Grabs all names of active observables for this environment. An observable is considered active if its value is
+ being returned in the observation dict from _get_observations() call or from the step() call (assuming this
+ observable is enabled).
+ Returns:
+ set: All active observation names
+ """
+ return set([name for name, observable in self._observables.items() if observable.is_active()])
+
+ @property
+ def _visualizations(self):
+ """
+ Visualization keywords for this environment
+ Returns:
+ set: All components that can be individually visualized for this environment
+ """
+ return {"env"}
+
+ @property
+ def action_spec(self):
+ """
+ Action specification should be implemented in subclasses.
+ Action space is represented by a tuple of (low, high), which are two numpy
+ vectors that specify the min/max action limits per dimension.
+ """
+ raise NotImplementedError
+
+ @property
+ def action_dim(self):
+ """
+ Size of the action space
+ Returns:
+ int: Action space dimension
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/__init__.py
@@ -0,0 +1 @@
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/door.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/door.py
new file mode 100644
index 0000000000000000000000000000000000000000..8953cf7ee69626285c7b241ca05dcb0b5e8c4ea5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/door.py
@@ -0,0 +1,461 @@
+from collections import OrderedDict
+
+import numpy as np
+
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import DoorObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import UniformRandomSampler
+
+
+class Door(SingleArmEnv):
+ """
+ This class corresponds to the door opening task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ use_latch (bool): if True, uses a spring-loaded handle and latch to "lock" the door closed initially
+ Otherwise, door is instantiated with a fixed handle
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ use_latch=True,
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # settings for table top (hardcoded since it's not an essential part of the environment)
+ self.table_full_size = (0.8, 0.3, 0.05)
+ self.table_offset = (-0.2, -0.35, 0.8)
+
+ # reward configuration
+ self.use_latch = use_latch
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 1.0 is provided if the door is opened
+
+ Un-normalized summed components if using reward shaping:
+
+ - Reaching: in [0, 0.25], proportional to the distance between door handle and robot arm
+ - Rotating: in [0, 0.25], proportional to angle rotated by door handled
+ - Note that this component is only relevant if the environment is using the locked door version
+
+ Note that a successfully completed task (door opened) will return 1.0 irregardless of whether the environment
+ is using sparse or shaped rewards
+
+ Note that the final reward is normalized and scaled by reward_scale / 1.0 as
+ well so that the max score is equal to reward_scale
+
+ Args:
+ action (np.array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ reward = 0.0
+
+ # sparse completion reward
+ if self._check_success():
+ reward = 1.0
+
+ # else, we consider only the case if we're using shaped rewards
+ elif self.reward_shaping:
+ # Add reaching component
+ dist = np.linalg.norm(self._gripper_to_handle)
+ reaching_reward = 0.25 * (1 - np.tanh(10.0 * dist))
+ reward += reaching_reward
+ # Add rotating component if we're using a locked door
+ if self.use_latch:
+ handle_qpos = self.sim.data.qpos[self.handle_qpos_addr]
+ reward += np.clip(0.25 * np.abs(handle_qpos / (0.5 * np.pi)), -0.25, 0.25)
+
+ # Scale reward if requested
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 1.0
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = TableArena(
+ table_full_size=self.table_full_size,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # Modify default agentview camera
+ mujoco_arena.set_camera(
+ camera_name="agentview",
+ pos=[0.5986131746834771, -4.392035683362857e-09, 1.5903500240372423],
+ quat=[0.6380177736282349, 0.3048497438430786, 0.30484986305236816, 0.6380177736282349],
+ )
+
+ # initialize objects of interest
+ self.door = DoorObject(
+ name="Door",
+ friction=0.0,
+ damping=0.1,
+ lock=self.use_latch,
+ )
+
+ # Create placement initializer
+ if self.placement_initializer is not None:
+ self.placement_initializer.reset()
+ self.placement_initializer.add_objects(self.door)
+ else:
+ self.placement_initializer = UniformRandomSampler(
+ name="ObjectSampler",
+ mujoco_objects=self.door,
+ x_range=[0.07, 0.09],
+ y_range=[-0.01, 0.01],
+ rotation=(-np.pi / 2.0 - 0.25, -np.pi / 2.0),
+ rotation_axis="z",
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=True,
+ reference_pos=self.table_offset,
+ )
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=self.door,
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.object_body_ids = dict()
+ self.object_body_ids["door"] = self.sim.model.body_name2id(self.door.door_body)
+ self.object_body_ids["frame"] = self.sim.model.body_name2id(self.door.frame_body)
+ self.object_body_ids["latch"] = self.sim.model.body_name2id(self.door.latch_body)
+ self.door_handle_site_id = self.sim.model.site_name2id(self.door.important_sites["handle"])
+ self.hinge_qpos_addr = self.sim.model.get_joint_qpos_addr(self.door.joints[0])
+ if self.use_latch:
+ self.handle_qpos_addr = self.sim.model.get_joint_qpos_addr(self.door.joints[1])
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ # Define sensor callbacks
+ @sensor(modality=modality)
+ def door_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.object_body_ids["door"]])
+
+ @sensor(modality=modality)
+ def handle_pos(obs_cache):
+ return self._handle_xpos
+
+ @sensor(modality=modality)
+ def door_to_eef_pos(obs_cache):
+ return (
+ obs_cache["door_pos"] - obs_cache[f"{pf}eef_pos"]
+ if "door_pos" in obs_cache and f"{pf}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def handle_to_eef_pos(obs_cache):
+ return (
+ obs_cache["handle_pos"] - obs_cache[f"{pf}eef_pos"]
+ if "handle_pos" in obs_cache and f"{pf}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def hinge_qpos(obs_cache):
+ return np.array([self.sim.data.qpos[self.hinge_qpos_addr]])
+
+ sensors = [door_pos, handle_pos, door_to_eef_pos, handle_to_eef_pos, hinge_qpos]
+ names = [s.__name__ for s in sensors]
+
+ # Also append handle qpos if we're using a locked door version with rotatable handle
+ if self.use_latch:
+
+ @sensor(modality=modality)
+ def handle_qpos(obs_cache):
+ return np.array([self.sim.data.qpos[self.handle_qpos_addr]])
+
+ sensors.append(handle_qpos)
+ names.append("handle_qpos")
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # We know we're only setting a single object (the door), so specifically set its pose
+ door_pos, door_quat, _ = object_placements[self.door.name]
+ door_body_id = self.sim.model.body_name2id(self.door.root_body)
+ self.sim.model.body_pos[door_body_id] = door_pos
+ self.sim.model.body_quat[door_body_id] = door_quat
+
+ def _check_success(self):
+ """
+ Check if door has been opened.
+
+ Returns:
+ bool: True if door has been opened
+ """
+ hinge_qpos = self.sim.data.qpos[self.hinge_qpos_addr]
+ return hinge_qpos > 0.3
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the door handle.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to the door handle
+ if vis_settings["grippers"]:
+ self._visualize_gripper_to_target(
+ gripper=self.robots[0].gripper, target=self.door.important_sites["handle"], target_type="site"
+ )
+
+ @property
+ def _handle_xpos(self):
+ """
+ Grabs the position of the door handle handle.
+
+ Returns:
+ np.array: Door handle (x,y,z)
+ """
+ return self.sim.data.site_xpos[self.door_handle_site_id]
+
+ @property
+ def _gripper_to_handle(self):
+ """
+ Calculates distance from the gripper to the door handle.
+
+ Returns:
+ np.array: (x,y,z) distance between handle and eef
+ """
+ return self._handle_xpos - self._eef_xpos
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/lift.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/lift.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d27b300f2e0f181b6ceb3bdb3ad7d76998909e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/lift.py
@@ -0,0 +1,428 @@
+from collections import OrderedDict
+
+import numpy as np
+
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import BoxObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.mjcf_utils import CustomMaterial
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import UniformRandomSampler
+from robosuite.utils.transform_utils import convert_quat
+
+
+class Lift(SingleArmEnv):
+ """
+ This class corresponds to the lifting task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.8))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 2.25 is provided if the cube is lifted
+
+ Un-normalized summed components if using reward shaping:
+
+ - Reaching: in [0, 1], to encourage the arm to reach the cube
+ - Grasping: in {0, 0.25}, non-zero if arm is grasping the cube
+ - Lifting: in {0, 1}, non-zero if arm has lifted the cube
+
+ The sparse reward only consists of the lifting component.
+
+ Note that the final reward is normalized and scaled by
+ reward_scale / 2.25 as well so that the max score is equal to reward_scale
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ reward = 0.0
+
+ # sparse completion reward
+ if self._check_success():
+ reward = 2.25
+
+ # use a shaping reward
+ elif self.reward_shaping:
+
+ # reaching reward
+ cube_pos = self.sim.data.body_xpos[self.cube_body_id]
+ gripper_site_pos = self.sim.data.site_xpos[self.robots[0].eef_site_id]
+ dist = np.linalg.norm(gripper_site_pos - cube_pos)
+ reaching_reward = 1 - np.tanh(10.0 * dist)
+ reward += reaching_reward
+
+ # grasping reward
+ if self._check_grasp(gripper=self.robots[0].gripper, object_geoms=self.cube):
+ reward += 0.25
+
+ # Scale reward if requested
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 2.25
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = TableArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # initialize objects of interest
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "1 1",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ redwood = CustomMaterial(
+ texture="WoodRed",
+ tex_name="redwood",
+ mat_name="redwood_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.cube = BoxObject(
+ name="cube",
+ size_min=[0.020, 0.020, 0.020], # [0.015, 0.015, 0.015],
+ size_max=[0.022, 0.022, 0.022], # [0.018, 0.018, 0.018])
+ rgba=[1, 0, 0, 1],
+ material=redwood,
+ )
+
+ # Create placement initializer
+ if self.placement_initializer is not None:
+ self.placement_initializer.reset()
+ self.placement_initializer.add_objects(self.cube)
+ else:
+ self.placement_initializer = UniformRandomSampler(
+ name="ObjectSampler",
+ mujoco_objects=self.cube,
+ x_range=[-0.03, 0.03],
+ y_range=[-0.03, 0.03],
+ rotation=None,
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=True,
+ reference_pos=self.table_offset,
+ z_offset=0.01,
+ )
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=self.cube,
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.cube_body_id = self.sim.model.body_name2id(self.cube.root_body)
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ # cube-related observables
+ @sensor(modality=modality)
+ def cube_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.cube_body_id])
+
+ @sensor(modality=modality)
+ def cube_quat(obs_cache):
+ return convert_quat(np.array(self.sim.data.body_xquat[self.cube_body_id]), to="xyzw")
+
+ @sensor(modality=modality)
+ def gripper_to_cube_pos(obs_cache):
+ return (
+ obs_cache[f"{pf}eef_pos"] - obs_cache["cube_pos"]
+ if f"{pf}eef_pos" in obs_cache and "cube_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ sensors = [cube_pos, cube_quat, gripper_to_cube_pos]
+ names = [s.__name__ for s in sensors]
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the cube.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to the cube
+ if vis_settings["grippers"]:
+ self._visualize_gripper_to_target(gripper=self.robots[0].gripper, target=self.cube)
+
+ def _check_success(self):
+ """
+ Check if cube has been lifted.
+
+ Returns:
+ bool: True if cube has been lifted
+ """
+ cube_height = self.sim.data.body_xpos[self.cube_body_id][2]
+ table_height = self.model.mujoco_arena.table_offset[2]
+
+ # cube is higher than the table top above a margin
+ return cube_height > table_height + 0.04
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/manipulation_env.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/manipulation_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c37a03c0e097a441f31b9a141266f7aa601cbd6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/manipulation_env.py
@@ -0,0 +1,322 @@
+import numpy as np
+
+from robosuite.environments.robot_env import RobotEnv
+from robosuite.models.base import MujocoModel
+from robosuite.models.grippers import GripperModel
+from robosuite.robots import ROBOT_CLASS_MAPPING, Manipulator
+
+
+class ManipulationEnv(RobotEnv):
+ """
+ Initializes a manipulation-specific robot environment in Mujoco.
+
+ Args:
+ robots: Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+
+ env_configuration (str): Specifies how to position the robot(s) within the environment. Default is "default",
+ which should be interpreted accordingly by any subclasses.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ mount_types (None or str or list of str): type of mount, used to instantiate mount models from mount factory.
+ Default is "default", which is the default mount associated with the robot(s) the 'robots' specification.
+ None results in no mount, and any other (valid) model overrides the default mount. Should either be
+ single str if same mount type is to be used for all robots or else it should be a list of the same
+ length as "robots" param
+
+ gripper_types (None or str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ ValueError: [Camera obs require offscreen renderer]
+ ValueError: [Camera name must be specified to use camera obs]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ mount_types="default",
+ gripper_types="default",
+ initialization_noise=None,
+ use_camera_obs=True,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None,
+ renderer="mujoco",
+ renderer_config=None,
+ direct_gripper_control=False,
+ ):
+ # Robot info
+ robots = list(robots) if type(robots) is list or type(robots) is tuple else [robots]
+ num_robots = len(robots)
+
+ # Gripper
+ gripper_types = self._input2list(gripper_types, num_robots)
+
+ # Robot configurations to pass to super call
+ robot_configs = [
+ {
+ "gripper_type": gripper_types[idx],
+ }
+ for idx in range(num_robots)
+ ]
+
+ # Run superclass init
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types=mount_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ robot_configs=robot_configs,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ direct_gripper_control=direct_gripper_control,
+ )
+
+ @property
+ def _visualizations(self):
+ """
+ Visualization keywords for this environment
+
+ Returns:
+ set: All components that can be individually visualized for this environment
+ """
+ vis_set = super()._visualizations
+ vis_set.add("grippers")
+ return vis_set
+
+ def _check_grasp(self, gripper, object_geoms):
+ """
+ Checks whether the specified gripper as defined by @gripper is grasping the specified object in the environment.
+
+ By default, this will return True if at least one geom in both the "left_fingerpad" and "right_fingerpad" geom
+ groups are in contact with any geom specified by @object_geoms. Custom gripper geom groups can be
+ specified with @gripper as well.
+
+ Args:
+ gripper (GripperModel or str or list of str or list of list of str): If a MujocoModel, this is specific
+ gripper to check for grasping (as defined by "left_fingerpad" and "right_fingerpad" geom groups). Otherwise,
+ this sets custom gripper geom groups which together define a grasp. This can be a string
+ (one group of single gripper geom), a list of string (multiple groups of single gripper geoms) or a
+ list of list of string (multiple groups of multiple gripper geoms). At least one geom from each group
+ must be in contact with any geom in @object_geoms for this method to return True.
+ object_geoms (str or list of str or MujocoModel): If a MujocoModel is inputted, will check for any
+ collisions with the model's contact_geoms. Otherwise, this should be specific geom name(s) composing
+ the object to check for contact.
+
+ Returns:
+ bool: True if the gripper is grasping the given object
+ """
+ # Convert object, gripper geoms into standardized form
+ if isinstance(object_geoms, MujocoModel):
+ o_geoms = object_geoms.contact_geoms
+ else:
+ o_geoms = [object_geoms] if type(object_geoms) is str else object_geoms
+ if isinstance(gripper, GripperModel):
+ g_geoms = [gripper.important_geoms["left_fingerpad"], gripper.important_geoms["right_fingerpad"]]
+ elif type(gripper) is str:
+ g_geoms = [[gripper]]
+ else:
+ # Parse each element in the gripper_geoms list accordingly
+ g_geoms = [[g_group] if type(g_group) is str else g_group for g_group in gripper]
+
+ # Search for collisions between each gripper geom group and the object geoms group
+ for g_group in g_geoms:
+ if not self.check_contact(g_group, o_geoms):
+ return False
+ return True
+
+ def _gripper_to_target(self, gripper, target, target_type="body", return_distance=False):
+ """
+ Calculates the (x,y,z) Cartesian distance (target_pos - gripper_pos) from the specified @gripper to the
+ specified @target. If @return_distance is set, will return the Euclidean (scalar) distance instead.
+
+ Args:
+ gripper (MujocoModel): Gripper model to update grip site rgb
+ target (MujocoModel or str): Either a site / geom / body name, or a model that serves as the target.
+ If a model is given, then the root body will be used as the target.
+ target_type (str): One of {"body", "geom", or "site"}, corresponding to the type of element @target
+ refers to.
+ return_distance (bool): If set, will return Euclidean distance instead of Cartesian distance
+
+ Returns:
+ np.array or float: (Cartesian or Euclidean) distance from gripper to target
+ """
+ # Get gripper and target positions
+ gripper_pos = self.sim.data.get_site_xpos(gripper.important_sites["grip_site"])
+ # If target is MujocoModel, grab the correct body as the target and find the target position
+ if isinstance(target, MujocoModel):
+ target_pos = self.sim.data.get_body_xpos(target.root_body)
+ elif target_type == "body":
+ target_pos = self.sim.data.get_body_xpos(target)
+ elif target_type == "site":
+ target_pos = self.sim.data.get_site_xpos(target)
+ else:
+ target_pos = self.sim.data.get_geom_xpos(target)
+ # Calculate distance
+ diff = target_pos - gripper_pos
+ # Return appropriate value
+ return np.linalg.norm(diff) if return_distance else diff
+
+ def _visualize_gripper_to_target(self, gripper, target, target_type="body"):
+ """
+ Colors the grip visualization site proportional to the Euclidean distance to the specified @target.
+ Colors go from red --> green as the gripper gets closer.
+
+ Args:
+ gripper (MujocoModel): Gripper model to update grip site rgb
+ target (MujocoModel or str): Either a site / geom / body name, or a model that serves as the target.
+ If a model is given, then the root body will be used as the target.
+ target_type (str): One of {"body", "geom", or "site"}, corresponding to the type of element @target
+ refers to.
+ """
+ # Get gripper and target positions
+ gripper_pos = self.sim.data.get_site_xpos(gripper.important_sites["grip_site"])
+ # If target is MujocoModel, grab the correct body as the target and find the target position
+ if isinstance(target, MujocoModel):
+ target_pos = self.sim.data.get_body_xpos(target.root_body)
+ elif target_type == "body":
+ target_pos = self.sim.data.get_body_xpos(target)
+ elif target_type == "site":
+ target_pos = self.sim.data.get_site_xpos(target)
+ else:
+ target_pos = self.sim.data.get_geom_xpos(target)
+ # color the gripper site appropriately based on (squared) distance to target
+ dist = np.sum(np.square((target_pos - gripper_pos)))
+ max_dist = 0.1
+ scaled = (1.0 - min(dist / max_dist, 1.0)) ** 15
+ rgba = np.zeros(3)
+ rgba[0] = 1 - scaled
+ rgba[1] = scaled
+ self.sim.model.site_rgba[self.sim.model.site_name2id(gripper.important_sites["grip_site"])][:3] = rgba
+
+ def _check_robot_configuration(self, robots):
+ """
+ Sanity check to make sure inputted robots and the corresponding requested task/configuration combo is legal.
+ Should be implemented in every specific task module
+
+ Args:
+ robots (str or list of str): Inputted requested robots at the task-level environment
+ """
+ # Make sure all inputted robots are a manipulation robot
+ if type(robots) is str:
+ robots = [robots]
+ for robot in robots:
+ assert issubclass(
+ ROBOT_CLASS_MAPPING[robot], Manipulator
+ ), "Only manipulator robots supported for manipulation environment!"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/nut_assembly.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/nut_assembly.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c0583c8ae2854f328191f78e4d16130521c8f94
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/nut_assembly.py
@@ -0,0 +1,708 @@
+import random
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import PegsArena
+from robosuite.models.objects import RoundNutObject, SquareNutObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import SequentialCompositeSampler, UniformRandomSampler
+
+
+class NutAssembly(SingleArmEnv):
+ """
+ This class corresponds to the nut assembly task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ single_object_mode (int): specifies which version of the task to do. Note that
+ the observations change accordingly.
+
+ :`0`: corresponds to the full task with both types of nuts.
+
+ :`1`: corresponds to an easier task with only one type of nut initialized
+ on the table with every reset. The type is randomized on every reset.
+
+ :`2`: corresponds to an easier task with only one type of nut initialized
+ on the table with every reset. The type is kept constant and will not
+ change between resets.
+
+ nut_type (string): if provided, should be either "round" or "square". Determines
+ which type of nut (round or square) will be spawned on every environment
+ reset. Only used if @single_object_mode is 2.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid nut type specified]
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1, 0.005, 0.0001),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ single_object_mode=0,
+ nut_type=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # task settings
+ self.single_object_mode = single_object_mode
+ self.nut_to_id = {"square": 0, "round": 1}
+ self.nut_id_to_sensors = {} # Maps nut id to sensor names for that nut
+ if nut_type is not None:
+ assert nut_type in self.nut_to_id.keys(), "invalid @nut_type argument - choose one of {}".format(
+ list(self.nut_to_id.keys())
+ )
+ self.nut_id = self.nut_to_id[nut_type] # use for convenient indexing
+ self.obj_to_use = None
+
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.82))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 1.0 per nut if it is placed around its correct peg
+
+ Un-normalized components if using reward shaping, where the maximum is returned if not solved:
+
+ - Reaching: in [0, 0.1], proportional to the distance between the gripper and the closest nut
+ - Grasping: in {0, 0.35}, nonzero if the gripper is grasping a nut
+ - Lifting: in {0, [0.35, 0.5]}, nonzero only if nut is grasped; proportional to lifting height
+ - Hovering: in {0, [0.5, 0.7]}, nonzero only if nut is lifted; proportional to distance from nut to peg
+
+ Note that a successfully completed task (nut around peg) will return 1.0 per nut irregardless of whether the
+ environment is using sparse or shaped rewards
+
+ Note that the final reward is normalized and scaled by reward_scale / 2.0 (or 1.0 if only a single nut is
+ being used) as well so that the max score is equal to reward_scale
+
+ Args:
+ action (np.array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ # compute sparse rewards
+ self._check_success()
+ reward = np.sum(self.objects_on_pegs)
+
+ # add in shaped rewards
+ if self.reward_shaping:
+ staged_rewards = self.staged_rewards()
+ reward += max(staged_rewards)
+ if self.reward_scale is not None:
+ reward *= self.reward_scale
+ if self.single_object_mode == 0:
+ reward /= 2.0
+ return reward
+
+ def staged_rewards(self):
+ """
+ Calculates staged rewards based on current physical states.
+ Stages consist of reaching, grasping, lifting, and hovering.
+
+ Returns:
+ 4-tuple:
+
+ - (float) reaching reward
+ - (float) grasping reward
+ - (float) lifting reward
+ - (float) hovering reward
+ """
+
+ reach_mult = 0.1
+ grasp_mult = 0.35
+ lift_mult = 0.5
+ hover_mult = 0.7
+
+ # filter out objects that are already on the correct pegs
+ active_nuts = []
+ for i, nut in enumerate(self.nuts):
+ if self.objects_on_pegs[i]:
+ continue
+ active_nuts.append(nut)
+
+ # reaching reward governed by distance to closest object
+ r_reach = 0.0
+ if active_nuts:
+ # reaching reward via minimum distance to the handles of the objects
+ dists = [
+ self._gripper_to_target(
+ gripper=self.robots[0].gripper,
+ target=active_nut.important_sites["handle"],
+ target_type="site",
+ return_distance=True,
+ )
+ for active_nut in active_nuts
+ ]
+ r_reach = (1 - np.tanh(10.0 * min(dists))) * reach_mult
+
+ # grasping reward for touching any objects of interest
+ r_grasp = (
+ int(
+ self._check_grasp(
+ gripper=self.robots[0].gripper,
+ object_geoms=[g for active_nut in active_nuts for g in active_nut.contact_geoms],
+ )
+ )
+ * grasp_mult
+ )
+
+ # lifting reward for picking up an object
+ r_lift = 0.0
+ table_pos = np.array(self.sim.data.body_xpos[self.table_body_id])
+ if active_nuts and r_grasp > 0.0:
+ z_target = table_pos[2] + 0.2
+ object_z_locs = self.sim.data.body_xpos[[self.obj_body_id[active_nut.name] for active_nut in active_nuts]][
+ :, 2
+ ]
+ z_dists = np.maximum(z_target - object_z_locs, 0.0)
+ r_lift = grasp_mult + (1 - np.tanh(15.0 * min(z_dists))) * (lift_mult - grasp_mult)
+
+ # hover reward for getting object above peg
+ r_hover = 0.0
+ if active_nuts:
+ r_hovers = np.zeros(len(active_nuts))
+ peg_body_ids = [self.peg1_body_id, self.peg2_body_id]
+ for i, nut in enumerate(active_nuts):
+ valid_obj = False
+ peg_pos = None
+ for nut_name, idn in self.nut_to_id.items():
+ if nut_name in nut.name.lower():
+ peg_pos = np.array(self.sim.data.body_xpos[peg_body_ids[idn]])[:2]
+ valid_obj = True
+ break
+ if not valid_obj:
+ raise Exception("Got invalid object to reach: {}".format(nut.name))
+ ob_xy = self.sim.data.body_xpos[self.obj_body_id[nut.name]][:2]
+ dist = np.linalg.norm(peg_pos - ob_xy)
+ r_hovers[i] = r_lift + (1 - np.tanh(10.0 * dist)) * (hover_mult - lift_mult)
+ r_hover = np.max(r_hovers)
+
+ return r_reach, r_grasp, r_lift, r_hover
+
+ def on_peg(self, obj_pos, peg_id):
+
+ if peg_id == 0:
+ peg_pos = np.array(self.sim.data.body_xpos[self.peg1_body_id])
+ else:
+ peg_pos = np.array(self.sim.data.body_xpos[self.peg2_body_id])
+ res = False
+ if (
+ abs(obj_pos[0] - peg_pos[0]) < 0.03
+ and abs(obj_pos[1] - peg_pos[1]) < 0.03
+ and obj_pos[2] < self.table_offset[2] + 0.05
+ ):
+ res = True
+ return res
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = PegsArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # define nuts
+ self.nuts = []
+ nut_names = ("SquareNut", "RoundNut")
+
+ # Create default (SequentialCompositeSampler) sampler if it has not already been specified
+ if self.placement_initializer is None:
+ self.placement_initializer = SequentialCompositeSampler(name="ObjectSampler")
+ for nut_name, default_y_range in zip(nut_names, ([0.11, 0.225], [-0.225, -0.11])):
+ self.placement_initializer.append_sampler(
+ sampler=UniformRandomSampler(
+ name=f"{nut_name}Sampler",
+ x_range=[-0.115, -0.11],
+ y_range=default_y_range,
+ rotation=None,
+ rotation_axis="z",
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=True,
+ reference_pos=self.table_offset,
+ z_offset=0.02,
+ )
+ )
+ # Reset sampler before adding any new samplers / objects
+ self.placement_initializer.reset()
+
+ for i, (nut_cls, nut_name) in enumerate(
+ zip(
+ (SquareNutObject, RoundNutObject),
+ nut_names,
+ )
+ ):
+ nut = nut_cls(name=nut_name)
+ self.nuts.append(nut)
+ # Add this nut to the placement initializer
+ if isinstance(self.placement_initializer, SequentialCompositeSampler):
+ # assumes we have two samplers so we add nuts to them
+ self.placement_initializer.add_objects_to_sampler(sampler_name=f"{nut_name}Sampler", mujoco_objects=nut)
+ else:
+ # This is assumed to be a flat sampler, so we just add all nuts to this sampler
+ self.placement_initializer.add_objects(nut)
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=self.nuts,
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.obj_body_id = {}
+ self.obj_geom_id = {}
+
+ self.table_body_id = self.sim.model.body_name2id("table")
+ self.peg1_body_id = self.sim.model.body_name2id("peg1")
+ self.peg2_body_id = self.sim.model.body_name2id("peg2")
+
+ for nut in self.nuts:
+ self.obj_body_id[nut.name] = self.sim.model.body_name2id(nut.root_body)
+ self.obj_geom_id[nut.name] = [self.sim.model.geom_name2id(g) for g in nut.contact_geoms]
+
+ # information of objects
+ self.object_site_ids = [self.sim.model.site_name2id(nut.important_sites["handle"]) for nut in self.nuts]
+
+ # keep track of which objects are on their corresponding pegs
+ self.objects_on_pegs = np.zeros(len(self.nuts))
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ # Reset nut sensor mappings
+ self.nut_id_to_sensors = {}
+
+ # for conversion to relative gripper frame
+ @sensor(modality=modality)
+ def world_pose_in_gripper(obs_cache):
+ return (
+ T.pose_inv(T.pose2mat((obs_cache[f"{pf}eef_pos"], obs_cache[f"{pf}eef_quat"])))
+ if f"{pf}eef_pos" in obs_cache and f"{pf}eef_quat" in obs_cache
+ else np.eye(4)
+ )
+
+ sensors = [world_pose_in_gripper]
+ names = ["world_pose_in_gripper"]
+ enableds = [True]
+ actives = [False]
+
+ # Define nut related sensors
+ for i, nut in enumerate(self.nuts):
+ # Create sensors for this nut
+ using_nut = self.single_object_mode == 0 or self.nut_id == i
+ nut_sensors, nut_sensor_names = self._create_nut_sensors(nut_name=nut.name, modality=modality)
+ sensors += nut_sensors
+ names += nut_sensor_names
+ enableds += [using_nut] * 4
+ actives += [using_nut] * 4
+ self.nut_id_to_sensors[i] = nut_sensor_names
+
+ if self.single_object_mode == 1:
+ # This is randomly sampled object, so we need to include object id as observation
+ @sensor(modality=modality)
+ def nut_id(obs_cache):
+ return self.nut_id
+
+ sensors.append(nut_id)
+ names.append("nut_id")
+ enableds.append(True)
+ actives.append(True)
+
+ # Create observables
+ for name, s, enabled, active in zip(names, sensors, enableds, actives):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ enabled=enabled,
+ active=active,
+ )
+
+ return observables
+
+ def _create_nut_sensors(self, nut_name, modality="object"):
+ """
+ Helper function to create sensors for a given nut. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+
+ Args:
+ nut_name (str): Name of nut to create sensors for
+ modality (str): Modality to assign to all sensors
+
+ Returns:
+ 2-tuple:
+ sensors (list): Array of sensors for the given nut
+ names (list): array of corresponding observable names
+ """
+ pf = self.robots[0].robot_model.naming_prefix
+
+ @sensor(modality=modality)
+ def nut_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.obj_body_id[nut_name]])
+
+ @sensor(modality=modality)
+ def nut_quat(obs_cache):
+ return T.convert_quat(self.sim.data.body_xquat[self.obj_body_id[nut_name]], to="xyzw")
+
+ @sensor(modality=modality)
+ def nut_to_eef_pos(obs_cache):
+ # Immediately return default value if cache is empty
+ if any(
+ [name not in obs_cache for name in [f"{nut_name}_pos", f"{nut_name}_quat", "world_pose_in_gripper"]]
+ ):
+ return np.zeros(3)
+ obj_pose = T.pose2mat((obs_cache[f"{nut_name}_pos"], obs_cache[f"{nut_name}_quat"]))
+ rel_pose = T.pose_in_A_to_pose_in_B(obj_pose, obs_cache["world_pose_in_gripper"])
+ rel_pos, rel_quat = T.mat2pose(rel_pose)
+ obs_cache[f"{nut_name}_to_{pf}eef_quat"] = rel_quat
+ return rel_pos
+
+ @sensor(modality=modality)
+ def nut_to_eef_quat(obs_cache):
+ return (
+ obs_cache[f"{nut_name}_to_{pf}eef_quat"] if f"{nut_name}_to_{pf}eef_quat" in obs_cache else np.zeros(4)
+ )
+
+ sensors = [nut_pos, nut_quat, nut_to_eef_pos, nut_to_eef_quat]
+ names = [f"{nut_name}_pos", f"{nut_name}_quat", f"{nut_name}_to_{pf}eef_pos", f"{nut_name}_to_{pf}eef_quat"]
+
+ return sensors, names
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ # Move objects out of the scene depending on the mode
+ nut_names = {nut.name for nut in self.nuts}
+ if self.single_object_mode == 1:
+ self.obj_to_use = random.choice(list(nut_names))
+ for nut_type, i in self.nut_to_id.items():
+ if nut_type.lower() in self.obj_to_use.lower():
+ self.nut_id = i
+ break
+ elif self.single_object_mode == 2:
+ self.obj_to_use = self.nuts[self.nut_id].name
+ if self.single_object_mode in {1, 2}:
+ nut_names.remove(self.obj_to_use)
+ self.clear_objects(list(nut_names))
+
+ # Make sure to update sensors' active and enabled states
+ if self.single_object_mode != 0:
+ for i, sensor_names in self.nut_id_to_sensors.items():
+ for name in sensor_names:
+ # Set all of these sensors to be enabled and active if this is the active nut, else False
+ self._observables[name].set_enabled(i == self.nut_id)
+ self._observables[name].set_active(i == self.nut_id)
+
+ def _check_success(self):
+ """
+ Check if all nuts have been successfully placed around their corresponding pegs.
+
+ Returns:
+ bool: True if all nuts are placed correctly
+ """
+ # remember objects that are on the correct pegs
+ gripper_site_pos = self.sim.data.site_xpos[self.robots[0].eef_site_id]
+ for i, nut in enumerate(self.nuts):
+ obj_str = nut.name
+ obj_pos = self.sim.data.body_xpos[self.obj_body_id[obj_str]]
+ dist = np.linalg.norm(gripper_site_pos - obj_pos)
+ r_reach = 1 - np.tanh(10.0 * dist)
+ self.objects_on_pegs[i] = int(self.on_peg(obj_pos, i) and r_reach < 0.6)
+
+ if self.single_object_mode > 0:
+ return np.sum(self.objects_on_pegs) > 0 # need one object on peg
+
+ # returns True if all objects are on correct pegs
+ return np.sum(self.objects_on_pegs) == len(self.nuts)
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the closest nut.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to the closest nut
+ if vis_settings["grippers"]:
+ # find closest object
+ dists = [
+ self._gripper_to_target(
+ gripper=self.robots[0].gripper,
+ target=nut.important_sites["handle"],
+ target_type="site",
+ return_distance=True,
+ )
+ for nut in self.nuts
+ ]
+ closest_nut_id = np.argmin(dists)
+ # Visualize the distance to this target
+ self._visualize_gripper_to_target(
+ gripper=self.robots[0].gripper,
+ target=self.nuts[closest_nut_id].important_sites["handle"],
+ target_type="site",
+ )
+
+
+class NutAssemblySingle(NutAssembly):
+ """
+ Easier version of task - place either one round nut or one square nut into its peg.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=1, **kwargs)
+
+
+class NutAssemblySquare(NutAssembly):
+ """
+ Easier version of task - place one square nut into its peg.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs and "nut_type" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=2, nut_type="square", **kwargs)
+
+
+class NutAssemblyRound(NutAssembly):
+ """
+ Easier version of task - place one round nut into its peg.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs and "nut_type" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=2, nut_type="round", **kwargs)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/phantom.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/phantom.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1b86ddea9ecf665d432e5f0e7adb45c1140ee0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/phantom.py
@@ -0,0 +1,299 @@
+import numpy as np
+
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import PhantomTableArena
+from robosuite.models.tasks import ManipulationTask
+
+
+class Phantom(SingleArmEnv):
+ """
+ This class corresponds to the stacking task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="frontview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ object_placements=None,
+ direct_gripper_control=False,
+ camera_pos=None,
+ camera_quat_wxyz=None,
+ camera_fov=None,
+ camera_sensorsize=None,
+ camera_principalpixel=None,
+ camera_focalpixel=None,
+ ):
+
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.8))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ self.object_placements = object_placements
+ self.camera_pos = camera_pos
+ self.camera_quat_wxyz = camera_quat_wxyz
+ self.camera_fov = camera_fov
+ self.camera_sensorsize = camera_sensorsize
+ self.camera_principalpixel = camera_principalpixel
+ self.camera_focalpixel = camera_focalpixel
+
+ # pdb.set_trace()
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ direct_gripper_control=direct_gripper_control,
+ )
+
+ def reset(self, object_placements=None):
+ self.object_placements = object_placements
+ return super().reset()
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = PhantomTableArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ )
+
+ # Modify default frontview camera
+ if self.camera_pos is not None:
+ robot_base_pos = np.array([-0.56, 0, 0.912])
+ mujoco_arena.set_camera(
+ camera_name="frontview",
+ pos=self.camera_pos + robot_base_pos,
+ quat=self.camera_quat_wxyz,
+ camera_attribs={"sensorsize": np.array2string(self.camera_sensorsize)[1:-1],
+ "resolution": f"{self.camera_widths[0]} {self.camera_heights[0]}",
+ "principalpixel": np.array2string(self.camera_principalpixel)[1:-1],
+ "focalpixel": np.array2string(self.camera_focalpixel)[1:-1],}
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the cube.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # # Color the gripper visualization site according to its distance to the cube
+ # if vis_settings["grippers"]:
+ # self._visualize_gripper_to_target(gripper=self.robots[0].gripper, target=self.cubeA)
+
+ def reward(self, action):
+ return 0.0
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/phantom_bimanual.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/phantom_bimanual.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf095f15977ef1ba6b68bc759150f928596a57e9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/phantom_bimanual.py
@@ -0,0 +1,341 @@
+
+from collections import OrderedDict
+
+import numpy as np
+import pdb
+from scipy.spatial.transform import Rotation
+
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.environments.manipulation.two_arm_env import TwoArmEnv
+# from robosuite.models.arenas import TableArena
+from robosuite.models.arenas import TableArena2, EmptyArena
+from robosuite.models.objects import BoxObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.mjcf_utils import CustomMaterial
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import UniformRandomSampler
+from robosuite.utils.transform_utils import convert_quat
+from robosuite.models.objects import BoxObject, CylinderObject
+
+class PhantomBimanual(TwoArmEnv):
+ """
+ This class corresponds to the stacking task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ bimanual_setup,
+ env_configuration="default",
+ controller_configs=None,
+ mount_types="default",
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="zed",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ object_placements=None,
+ direct_gripper_control=False,
+ camera_pos=None,
+ camera_quat_wxyz=None,
+ camera_fov=None,
+ camera_sensorsize=None,
+ camera_principalpixel=None,
+ camera_focalpixel=None,
+ ):
+
+ self.bimanual_setup = bimanual_setup
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.4))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ self.object_placements = object_placements
+ self.camera_pos = camera_pos
+ self.camera_quat_wxyz = camera_quat_wxyz
+ self.camera_fov = camera_fov
+ self.camera_sensorsize = camera_sensorsize
+ self.camera_principalpixel = camera_principalpixel
+ self.camera_focalpixel = camera_focalpixel
+
+ self.robot_base_height = 2.0
+ self.robot_base_offset = -0.5
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types=mount_types,
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ direct_gripper_control=direct_gripper_control,
+ )
+
+ def reset(self, object_placements=None):
+ self.object_placements = object_placements
+ return super().reset()
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ if self.bimanual_setup == "tabletop":
+ count = 0
+ for robot, offset, rotation in zip(self.robots, (-0.2, 0.2), (0, 0)):
+ xpos = np.array((0, offset, self.robot_base_height))
+ robot.robot_model.set_base_xpos(xpos)
+ rot = np.array((rotation, 0, np.pi)) if count == 1 else np.array((rotation, 0, 0))
+ robot.robot_model.set_base_ori(rot)
+ count += 1
+ elif self.bimanual_setup == "shoulders1":
+ count = 0
+ for robot, offset, rotation in zip(self.robots, (-0.2, 0.2), (np.pi*2/3, -np.pi*2/3)):
+ xpos = np.array((0, offset, self.robot_base_height))
+ robot.robot_model.set_base_xpos(xpos)
+ rot = np.array((rotation, 0, np.pi)) if count == 1 else np.array((rotation, 0, 0))
+ robot.robot_model.set_base_ori(rot)
+ count += 1
+ elif self.bimanual_setup == "shoulders2":
+ count = 0
+ for robot, offset, rotation in zip(self.robots, (-0.2, 0.2), (np.pi/3, -np.pi/3)):
+ xpos = np.array((0, offset, self.robot_base_height))
+ robot.robot_model.set_base_xpos(xpos)
+ rot = np.array((rotation, 0, np.pi)) if count == 1 else np.array((rotation, 0, 0))
+ robot.robot_model.set_base_ori(rot)
+ count += 1
+ elif self.bimanual_setup == "shoulders":
+ count = 0
+ for robot, offset, rotation in zip(self.robots, (-0.2, 0.2), (np.pi/3, -np.pi/3)):
+ if count == 1:
+ xpos = np.array((0, 0.2, self.robot_base_height+self.robot_base_offset+robot.robot_model.bottom_offset[2]))
+ else:
+ xpos = np.array((-0.00656507, -0.14111039, 1.58980033+robot.robot_model.bottom_offset[2]))
+ robot.robot_model.set_base_xpos(xpos)
+ if count == 1:
+ rot = np.array((rotation, 0, np.pi/2))
+ else:
+ rot = np.array((0.50415113, -0.05164374, -1.57347674))
+ robot.robot_model.set_base_ori(rot)
+ count += 1
+
+ mujoco_arena = EmptyArena()
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ )
+
+ # Modify zed camera
+ if self.camera_pos is not None:
+
+ mujoco_arena.set_camera(
+ camera_name="zed",
+ pos=self.camera_pos,
+ quat=self.camera_quat_wxyz,
+ camera_attribs={"sensorsize": np.array2string(self.camera_sensorsize)[1:-1],
+ "resolution": f"{self.camera_widths[0]} {self.camera_heights[0]}",
+ "principalpixel": np.array2string(self.camera_principalpixel)[1:-1],
+ "focalpixel": np.array2string(self.camera_focalpixel)[1:-1],}
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the cube.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ def reward(self, action):
+ return 0.0
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/pick_place.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/pick_place.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69a718d83830a3d0d5db618779e49dae3c8d717
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/pick_place.py
@@ -0,0 +1,838 @@
+import random
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import BinsArena
+from robosuite.models.objects import (
+ BreadObject,
+ BreadVisualObject,
+ CanObject,
+ CanVisualObject,
+ CerealObject,
+ CerealVisualObject,
+ MilkObject,
+ MilkVisualObject,
+)
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import SequentialCompositeSampler, UniformRandomSampler
+
+
+class PickPlace(SingleArmEnv):
+ """
+ This class corresponds to the pick place task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ bin1_pos (3-tuple): Absolute cartesian coordinates of the bin initially holding the objects
+
+ bin2_pos (3-tuple): Absolute cartesian coordinates of the goal bin
+
+ z_offset (float): amount of z offset for initializing objects in bin
+
+ z_rotation (float, tuple, or None): if provided, controls the range of z-rotation initialization
+ for the objects
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ single_object_mode (int): specifies which version of the task to do. Note that
+ the observations change accordingly.
+
+ :`0`: corresponds to the full task with all types of objects.
+
+ :`1`: corresponds to an easier task with only one type of object initialized
+ on the table with every reset. The type is randomized on every reset.
+
+ :`2`: corresponds to an easier task with only one type of object initialized
+ on the table with every reset. The type is kept constant and will not
+ change between resets.
+
+ object_type (string): if provided, should be one of "milk", "bread", "cereal",
+ or "can". Determines which type of object will be spawned on every
+ environment reset. Only used if @single_object_mode is 2.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid object type specified]
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.39, 0.49, 0.82),
+ table_friction=(1, 0.005, 0.0001),
+ bin1_pos=(0.1, -0.25, 0.8),
+ bin2_pos=(0.1, 0.28, 0.8),
+ z_offset=0.,
+ z_rotation=None,
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ single_object_mode=0,
+ object_type=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # task settings
+ self.single_object_mode = single_object_mode
+ self.object_to_id = {"milk": 0, "bread": 1, "cereal": 2, "can": 3}
+ self.object_id_to_sensors = {} # Maps object id to sensor names for that object
+ self.obj_names = ["Milk", "Bread", "Cereal", "Can"]
+ if object_type is not None:
+ assert object_type in self.object_to_id.keys(), "invalid @object_type argument - choose one of {}".format(
+ list(self.object_to_id.keys())
+ )
+ self.object_id = self.object_to_id[object_type] # use for convenient indexing
+ self.obj_to_use = None
+
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+
+ # settings for bin position
+ self.bin1_pos = np.array(bin1_pos)
+ self.bin2_pos = np.array(bin2_pos)
+ self.z_offset = z_offset # z offset for initializing items in bin
+ self.z_rotation = z_rotation # z rotation for initializing items in bin
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 1.0 per object if it is placed in its correct bin
+
+ Un-normalized components if using reward shaping, where the maximum is returned if not solved:
+
+ - Reaching: in [0, 0.1], proportional to the distance between the gripper and the closest object
+ - Grasping: in {0, 0.35}, nonzero if the gripper is grasping an object
+ - Lifting: in {0, [0.35, 0.5]}, nonzero only if object is grasped; proportional to lifting height
+ - Hovering: in {0, [0.5, 0.7]}, nonzero only if object is lifted; proportional to distance from object to bin
+
+ Note that a successfully completed task (object in bin) will return 1.0 per object irregardless of whether the
+ environment is using sparse or shaped rewards
+
+ Note that the final reward is normalized and scaled by reward_scale / 4.0 (or 1.0 if only a single object is
+ being used) as well so that the max score is equal to reward_scale
+
+ Args:
+ action (np.array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ # compute sparse rewards
+ self._check_success()
+ reward = np.sum(self.objects_in_bins)
+
+ # add in shaped rewards
+ if self.reward_shaping:
+ staged_rewards = self.staged_rewards()
+ reward += max(staged_rewards)
+ if self.reward_scale is not None:
+ reward *= self.reward_scale
+ if self.single_object_mode == 0:
+ reward /= 4.0
+ return reward
+
+ def staged_rewards(self):
+ """
+ Returns staged rewards based on current physical states.
+ Stages consist of reaching, grasping, lifting, and hovering.
+
+ Returns:
+ 4-tuple:
+
+ - (float) reaching reward
+ - (float) grasping reward
+ - (float) lifting reward
+ - (float) hovering reward
+ """
+
+ reach_mult = 0.1
+ grasp_mult = 0.35
+ lift_mult = 0.5
+ hover_mult = 0.7
+
+ # filter out objects that are already in the correct bins
+ active_objs = []
+ for i, obj in enumerate(self.objects):
+ if self.objects_in_bins[i]:
+ continue
+ active_objs.append(obj)
+
+ # reaching reward governed by distance to closest object
+ r_reach = 0.0
+ if active_objs:
+ # get reaching reward via minimum distance to a target object
+ dists = [
+ self._gripper_to_target(
+ gripper=self.robots[0].gripper,
+ target=active_obj.root_body,
+ target_type="body",
+ return_distance=True,
+ )
+ for active_obj in active_objs
+ ]
+ r_reach = (1 - np.tanh(10.0 * min(dists))) * reach_mult
+
+ # grasping reward for touching any objects of interest
+ r_grasp = (
+ int(
+ self._check_grasp(
+ gripper=self.robots[0].gripper,
+ object_geoms=[g for active_obj in active_objs for g in active_obj.contact_geoms],
+ )
+ )
+ * grasp_mult
+ )
+
+ # lifting reward for picking up an object
+ r_lift = 0.0
+ if active_objs and r_grasp > 0.0:
+ z_target = self.bin2_pos[2] + 0.25
+ object_z_locs = self.sim.data.body_xpos[[self.obj_body_id[active_obj.name] for active_obj in active_objs]][
+ :, 2
+ ]
+ z_dists = np.maximum(z_target - object_z_locs, 0.0)
+ r_lift = grasp_mult + (1 - np.tanh(15.0 * min(z_dists))) * (lift_mult - grasp_mult)
+
+ # hover reward for getting object above bin
+ r_hover = 0.0
+ if active_objs:
+ target_bin_ids = [self.object_to_id[active_obj.name.lower()] for active_obj in active_objs]
+ # segment objects into left of the bins and above the bins
+ object_xy_locs = self.sim.data.body_xpos[[self.obj_body_id[active_obj.name] for active_obj in active_objs]][
+ :, :2
+ ]
+ y_check = (
+ np.abs(object_xy_locs[:, 1] - self.target_bin_placements[target_bin_ids, 1]) < self.bin_size[1] / 4.0
+ )
+ x_check = (
+ np.abs(object_xy_locs[:, 0] - self.target_bin_placements[target_bin_ids, 0]) < self.bin_size[0] / 4.0
+ )
+ objects_above_bins = np.logical_and(x_check, y_check)
+ objects_not_above_bins = np.logical_not(objects_above_bins)
+ dists = np.linalg.norm(self.target_bin_placements[target_bin_ids, :2] - object_xy_locs, axis=1)
+ # objects to the left get r_lift added to hover reward,
+ # those on the right get max(r_lift) added (to encourage dropping)
+ r_hover_all = np.zeros(len(active_objs))
+ r_hover_all[objects_above_bins] = lift_mult + (1 - np.tanh(10.0 * dists[objects_above_bins])) * (
+ hover_mult - lift_mult
+ )
+ r_hover_all[objects_not_above_bins] = r_lift + (1 - np.tanh(10.0 * dists[objects_not_above_bins])) * (
+ hover_mult - lift_mult
+ )
+ r_hover = np.max(r_hover_all)
+
+ return r_reach, r_grasp, r_lift, r_hover
+
+ def not_in_bin(self, obj_pos, bin_id):
+
+ bin_x_low = self.bin2_pos[0]
+ bin_y_low = self.bin2_pos[1]
+ if bin_id == 0 or bin_id == 2:
+ bin_x_low -= self.bin_size[0] / 2
+ if bin_id < 2:
+ bin_y_low -= self.bin_size[1] / 2
+
+ bin_x_high = bin_x_low + self.bin_size[0] / 2
+ bin_y_high = bin_y_low + self.bin_size[1] / 2
+
+ res = True
+ if (
+ bin_x_low < obj_pos[0] < bin_x_high
+ and bin_y_low < obj_pos[1] < bin_y_high
+ and self.bin2_pos[2] < obj_pos[2] < self.bin2_pos[2] + 0.1
+ ):
+ res = False
+ return res
+
+ def _get_placement_initializer(self):
+ """
+ Helper function for defining placement initializer and object sampling bounds.
+ """
+ self.placement_initializer = SequentialCompositeSampler(name="ObjectSampler")
+
+ # can sample anywhere in bin
+ bin_x_half = self.model.mujoco_arena.table_full_size[0] / 2 - 0.05
+ bin_y_half = self.model.mujoco_arena.table_full_size[1] / 2 - 0.05
+
+ # each object should just be sampled in the bounds of the bin (with some tolerance)
+ self.placement_initializer.append_sampler(
+ sampler=UniformRandomSampler(
+ name="CollisionObjectSampler",
+ mujoco_objects=self.objects,
+ x_range=[-bin_x_half, bin_x_half],
+ y_range=[-bin_y_half, bin_y_half],
+ rotation=self.z_rotation,
+ rotation_axis="z",
+ ensure_object_boundary_in_range=True,
+ ensure_valid_placement=True,
+ reference_pos=self.bin1_pos,
+ z_offset=self.z_offset,
+ )
+ )
+
+ # each visual object should just be at the center of each target bin
+ index = 0
+ for vis_obj in self.visual_objects:
+
+ # get center of target bin
+ bin_x_low = self.bin2_pos[0]
+ bin_y_low = self.bin2_pos[1]
+ if index == 0 or index == 2:
+ bin_x_low -= self.bin_size[0] / 2
+ if index < 2:
+ bin_y_low -= self.bin_size[1] / 2
+ bin_x_high = bin_x_low + self.bin_size[0] / 2
+ bin_y_high = bin_y_low + self.bin_size[1] / 2
+ bin_center = np.array(
+ [
+ (bin_x_low + bin_x_high) / 2.0,
+ (bin_y_low + bin_y_high) / 2.0,
+ ]
+ )
+
+ # placement is relative to object bin, so compute difference and send to placement initializer
+ rel_center = bin_center - self.bin1_pos[:2]
+
+ self.placement_initializer.append_sampler(
+ sampler=UniformRandomSampler(
+ name=f"{vis_obj.name}ObjectSampler",
+ mujoco_objects=vis_obj,
+ x_range=[rel_center[0], rel_center[0]],
+ y_range=[rel_center[1], rel_center[1]],
+ rotation=0.0,
+ rotation_axis="z",
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=False,
+ reference_pos=self.bin1_pos,
+ z_offset=self.bin2_pos[2] - self.bin1_pos[2],
+ )
+ )
+ index += 1
+
+ def _construct_visual_objects(self):
+ """
+ Function that can be overriden by subclasses to load different objects.
+ """
+ self.visual_objects = []
+ for vis_obj_cls, obj_name in zip(
+ (MilkVisualObject, BreadVisualObject, CerealVisualObject, CanVisualObject),
+ self.obj_names,
+ ):
+ vis_name = "Visual" + obj_name
+ vis_obj = vis_obj_cls(name=vis_name)
+ self.visual_objects.append(vis_obj)
+
+ def _construct_objects(self):
+ """
+ Function that can be overriden by subclasses to load different objects.
+ """
+ self.objects = []
+ for obj_cls, obj_name in zip(
+ (MilkObject, BreadObject, CerealObject, CanObject),
+ self.obj_names,
+ ):
+ obj = obj_cls(name=obj_name)
+ self.objects.append(obj)
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["bins"]
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = BinsArena(
+ bin1_pos=self.bin1_pos, table_full_size=self.table_full_size, table_friction=self.table_friction
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # store some arena attributes
+ self.bin_size = mujoco_arena.table_full_size
+
+ # make objects
+ self._construct_visual_objects()
+ self._construct_objects()
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=self.visual_objects + self.objects,
+ )
+
+ # Generate placement initializer
+ self._get_placement_initializer()
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.obj_body_id = {}
+ self.obj_geom_id = {}
+
+ # object-specific ids
+ for obj in self.visual_objects + self.objects:
+ self.obj_body_id[obj.name] = self.sim.model.body_name2id(obj.root_body)
+ self.obj_geom_id[obj.name] = [self.sim.model.geom_name2id(g) for g in obj.contact_geoms]
+
+ # keep track of which objects are in their corresponding bins
+ self.objects_in_bins = np.zeros(len(self.objects))
+
+ # target locations in bin for each object type
+ self.target_bin_placements = np.zeros((len(self.objects), 3))
+ for i, obj in enumerate(self.objects):
+ bin_id = i
+ bin_x_low = self.bin2_pos[0]
+ bin_y_low = self.bin2_pos[1]
+ if bin_id == 0 or bin_id == 2:
+ bin_x_low -= self.bin_size[0] / 2.0
+ if bin_id < 2:
+ bin_y_low -= self.bin_size[1] / 2.0
+ bin_x_low += self.bin_size[0] / 4.0
+ bin_y_low += self.bin_size[1] / 4.0
+ self.target_bin_placements[i, :] = [bin_x_low, bin_y_low, self.bin2_pos[2]]
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ # Reset obj sensor mappings
+ self.object_id_to_sensors = {}
+
+ # for conversion to relative gripper frame
+ @sensor(modality=modality)
+ def world_pose_in_gripper(obs_cache):
+ return (
+ T.pose_inv(T.pose2mat((obs_cache[f"{pf}eef_pos"], obs_cache[f"{pf}eef_quat"])))
+ if f"{pf}eef_pos" in obs_cache and f"{pf}eef_quat" in obs_cache
+ else np.eye(4)
+ )
+
+ sensors = [world_pose_in_gripper]
+ names = ["world_pose_in_gripper"]
+ enableds = [True]
+ actives = [False]
+
+ for i, obj in enumerate(self.objects):
+ # Create object sensors
+ using_obj = self.single_object_mode == 0 or self.object_id == i
+ obj_sensors, obj_sensor_names = self._create_obj_sensors(obj_name=obj.name, modality=modality)
+ sensors += obj_sensors
+ names += obj_sensor_names
+ enableds += [using_obj] * 4
+ actives += [using_obj] * 4
+ self.object_id_to_sensors[i] = obj_sensor_names
+
+ if self.single_object_mode == 1:
+ # This is randomly sampled object, so we need to include object id as observation
+ @sensor(modality=modality)
+ def obj_id(obs_cache):
+ return self.object_id
+
+ sensors.append(obj_id)
+ names.append("obj_id")
+ enableds.append(True)
+ actives.append(True)
+
+ # Create observables
+ for name, s, enabled, active in zip(names, sensors, enableds, actives):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ enabled=enabled,
+ active=active,
+ )
+
+ return observables
+
+ def _create_obj_sensors(self, obj_name, modality="object"):
+ """
+ Helper function to create sensors for a given object. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+
+ Args:
+ obj_name (str): Name of object to create sensors for
+ modality (str): Modality to assign to all sensors
+
+ Returns:
+ 2-tuple:
+ sensors (list): Array of sensors for the given obj
+ names (list): array of corresponding observable names
+ """
+ pf = self.robots[0].robot_model.naming_prefix
+
+ @sensor(modality=modality)
+ def obj_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.obj_body_id[obj_name]])
+
+ @sensor(modality=modality)
+ def obj_quat(obs_cache):
+ return T.convert_quat(self.sim.data.body_xquat[self.obj_body_id[obj_name]], to="xyzw")
+
+ @sensor(modality=modality)
+ def obj_to_eef_pos(obs_cache):
+ # Immediately return default value if cache is empty
+ if any(
+ [name not in obs_cache for name in [f"{obj_name}_pos", f"{obj_name}_quat", "world_pose_in_gripper"]]
+ ):
+ return np.zeros(3)
+ obj_pose = T.pose2mat((obs_cache[f"{obj_name}_pos"], obs_cache[f"{obj_name}_quat"]))
+ rel_pose = T.pose_in_A_to_pose_in_B(obj_pose, obs_cache["world_pose_in_gripper"])
+ rel_pos, rel_quat = T.mat2pose(rel_pose)
+ obs_cache[f"{obj_name}_to_{pf}eef_quat"] = rel_quat
+ return rel_pos
+
+ @sensor(modality=modality)
+ def obj_to_eef_quat(obs_cache):
+ return (
+ obs_cache[f"{obj_name}_to_{pf}eef_quat"] if f"{obj_name}_to_{pf}eef_quat" in obs_cache else np.zeros(4)
+ )
+
+ sensors = [obj_pos, obj_quat, obj_to_eef_pos, obj_to_eef_quat]
+ names = [f"{obj_name}_pos", f"{obj_name}_quat", f"{obj_name}_to_{pf}eef_pos", f"{obj_name}_to_{pf}eef_quat"]
+
+ return sensors, names
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ # Set the visual object body locations
+ if "visual" in obj.name.lower():
+ self.sim.model.body_pos[self.obj_body_id[obj.name]] = obj_pos
+ self.sim.model.body_quat[self.obj_body_id[obj.name]] = obj_quat
+ else:
+ # Set the collision object joints
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ # Set the bins to the desired position
+ self.sim.model.body_pos[self.sim.model.body_name2id("bin1")] = self.bin1_pos
+ self.sim.model.body_pos[self.sim.model.body_name2id("bin2")] = self.bin2_pos
+
+ # Move objects out of the scene depending on the mode
+ obj_names = {obj.name for obj in self.objects}
+ if self.single_object_mode == 1:
+ self.obj_to_use = random.choice(list(obj_names))
+ for obj_type, i in self.object_to_id.items():
+ if obj_type.lower() in self.obj_to_use.lower():
+ self.object_id = i
+ break
+ elif self.single_object_mode == 2:
+ self.obj_to_use = self.objects[self.object_id].name
+ if self.single_object_mode in {1, 2}:
+ obj_names.remove(self.obj_to_use)
+ self.clear_objects(list(obj_names))
+
+ # Make sure to update sensors' active and enabled states
+ if self.single_object_mode != 0:
+ for i, sensor_names in self.object_id_to_sensors.items():
+ for name in sensor_names:
+ # Set all of these sensors to be enabled and active if this is the active object, else False
+ self._observables[name].set_enabled(i == self.object_id)
+ self._observables[name].set_active(i == self.object_id)
+
+ def _check_success(self):
+ """
+ Check if all objects have been successfully placed in their corresponding bins.
+
+ Returns:
+ bool: True if all objects are placed correctly
+ """
+ # remember objects that are in the correct bins
+ gripper_site_pos = self.sim.data.site_xpos[self.robots[0].eef_site_id]
+ for i, obj in enumerate(self.objects):
+ obj_str = obj.name
+ obj_pos = self.sim.data.body_xpos[self.obj_body_id[obj_str]]
+ dist = np.linalg.norm(gripper_site_pos - obj_pos)
+ r_reach = 1 - np.tanh(10.0 * dist)
+ self.objects_in_bins[i] = int((not self.not_in_bin(obj_pos, i)) and r_reach < 0.6)
+
+ # returns True if a single object is in the correct bin
+ if self.single_object_mode in {1, 2}:
+ return np.sum(self.objects_in_bins) > 0
+
+ # returns True if all objects are in correct bins
+ return np.sum(self.objects_in_bins) == len(self.objects)
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the closest object.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to the closest object
+ if vis_settings["grippers"]:
+ # find closest object
+ dists = [
+ self._gripper_to_target(
+ gripper=self.robots[0].gripper,
+ target=obj.root_body,
+ target_type="body",
+ return_distance=True,
+ )
+ for obj in self.objects
+ ]
+ closest_obj_id = np.argmin(dists)
+ # Visualize the distance to this target
+ self._visualize_gripper_to_target(
+ gripper=self.robots[0].gripper,
+ target=self.objects[closest_obj_id].root_body,
+ target_type="body",
+ )
+
+
+class PickPlaceSingle(PickPlace):
+ """
+ Easier version of task - place one object into its bin.
+ A new object is sampled on every reset.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=1, **kwargs)
+
+
+class PickPlaceMilk(PickPlace):
+ """
+ Easier version of task - place one milk into its bin.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs and "object_type" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=2, object_type="milk", **kwargs)
+
+
+class PickPlaceBread(PickPlace):
+ """
+ Easier version of task - place one bread into its bin.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs and "object_type" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=2, object_type="bread", **kwargs)
+
+
+class PickPlaceCereal(PickPlace):
+ """
+ Easier version of task - place one cereal into its bin.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs and "object_type" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=2, object_type="cereal", **kwargs)
+
+
+class PickPlaceCan(PickPlace):
+ """
+ Easier version of task - place one can into its bin.
+ """
+
+ def __init__(self, **kwargs):
+ assert "single_object_mode" not in kwargs and "object_type" not in kwargs, "invalid set of arguments"
+ super().__init__(single_object_mode=2, object_type="can", **kwargs)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/single_arm_env.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/single_arm_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc3c9ac41557d644fdc7476bb1b1c488fd43b9e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/single_arm_env.py
@@ -0,0 +1,72 @@
+import numpy as np
+
+from robosuite.environments.manipulation.manipulation_env import ManipulationEnv
+from robosuite.robots import SingleArm
+from robosuite.utils.transform_utils import mat2quat
+
+
+class SingleArmEnv(ManipulationEnv):
+ """
+ A manipulation environment intended for a single robot arm.
+ """
+
+ def _load_model(self):
+ """
+ Verifies correct robot model is loaded
+ """
+ super()._load_model()
+
+ # Verify the correct robot has been loaded
+ assert isinstance(
+ self.robots[0], SingleArm
+ ), "Error: Expected one single-armed robot! Got {} type instead.".format(type(self.robots[0]))
+
+ def _check_robot_configuration(self, robots):
+ """
+ Sanity check to make sure the inputted robots and configuration is acceptable
+
+ Args:
+ robots (str or list of str): Robots to instantiate within this env
+ """
+ super()._check_robot_configuration(robots)
+ if type(robots) is list:
+ assert len(robots) == 1, "Error: Only one robot should be inputted for this task!"
+
+ @property
+ def _eef_xpos(self):
+ """
+ Grabs End Effector position
+
+ Returns:
+ np.array: End effector(x,y,z)
+ """
+ return np.array(self.sim.data.site_xpos[self.robots[0].eef_site_id])
+
+ @property
+ def _eef_xmat(self):
+ """
+ End Effector orientation as a rotation matrix
+ Note that this draws the orientation from the "ee" site, NOT the gripper site, since the gripper
+ orientations are inconsistent!
+
+ Returns:
+ np.array: (3,3) End Effector orientation matrix
+ """
+ pf = self.robots[0].gripper.naming_prefix
+
+ if self.env_configuration == "bimanual":
+ return np.array(self.sim.data.site_xmat[self.sim.model.site_name2id(pf + "right_grip_site")]).reshape(3, 3)
+ else:
+ return np.array(self.sim.data.site_xmat[self.sim.model.site_name2id(pf + "grip_site")]).reshape(3, 3)
+
+ @property
+ def _eef_xquat(self):
+ """
+ End Effector orientation as a (x,y,z,w) quaternion
+ Note that this draws the orientation from the "ee" site, NOT the gripper site, since the gripper
+ orientations are inconsistent!
+
+ Returns:
+ np.array: (x,y,z,w) End Effector quaternion
+ """
+ return mat2quat(self._eef_xmat)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/stack.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..992ab7e2111a11d1efa836c7c06003249bd7ae2c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/stack.py
@@ -0,0 +1,499 @@
+from collections import OrderedDict
+
+import numpy as np
+
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import BoxObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.mjcf_utils import CustomMaterial
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import UniformRandomSampler
+from robosuite.utils.transform_utils import convert_quat
+
+
+class Stack(SingleArmEnv):
+ """
+ This class corresponds to the stacking task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.8))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 2.0 is provided if the red block is stacked on the green block
+
+ Un-normalized components if using reward shaping:
+
+ - Reaching: in [0, 0.25], to encourage the arm to reach the cube
+ - Grasping: in {0, 0.25}, non-zero if arm is grasping the cube
+ - Lifting: in {0, 1}, non-zero if arm has lifted the cube
+ - Aligning: in [0, 0.5], encourages aligning one cube over the other
+ - Stacking: in {0, 2}, non-zero if cube is stacked on other cube
+
+ The reward is max over the following:
+
+ - Reaching + Grasping
+ - Lifting + Aligning
+ - Stacking
+
+ The sparse reward only consists of the stacking component.
+
+ Note that the final reward is normalized and scaled by
+ reward_scale / 2.0 as well so that the max score is equal to reward_scale
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ r_reach, r_lift, r_stack = self.staged_rewards()
+ if self.reward_shaping:
+ reward = max(r_reach, r_lift, r_stack)
+ else:
+ reward = 2.0 if r_stack > 0 else 0.0
+
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 2.0
+
+ return reward
+
+ def staged_rewards(self):
+ """
+ Helper function to calculate staged rewards based on current physical states.
+
+ Returns:
+ 3-tuple:
+
+ - (float): reward for reaching and grasping
+ - (float): reward for lifting and aligning
+ - (float): reward for stacking
+ """
+ # reaching is successful when the gripper site is close to the center of the cube
+ cubeA_pos = self.sim.data.body_xpos[self.cubeA_body_id]
+ cubeB_pos = self.sim.data.body_xpos[self.cubeB_body_id]
+ gripper_site_pos = self.sim.data.site_xpos[self.robots[0].eef_site_id]
+ dist = np.linalg.norm(gripper_site_pos - cubeA_pos)
+ r_reach = (1 - np.tanh(10.0 * dist)) * 0.25
+
+ # grasping reward
+ grasping_cubeA = self._check_grasp(gripper=self.robots[0].gripper, object_geoms=self.cubeA)
+ if grasping_cubeA:
+ r_reach += 0.25
+
+ # lifting is successful when the cube is above the table top by a margin
+ cubeA_height = cubeA_pos[2]
+ table_height = self.table_offset[2]
+ cubeA_lifted = cubeA_height > table_height + 0.04
+ r_lift = 1.0 if cubeA_lifted else 0.0
+
+ # Aligning is successful when cubeA is right above cubeB
+ if cubeA_lifted:
+ horiz_dist = np.linalg.norm(np.array(cubeA_pos[:2]) - np.array(cubeB_pos[:2]))
+ r_lift += 0.5 * (1 - np.tanh(horiz_dist))
+
+ # stacking is successful when the block is lifted and the gripper is not holding the object
+ r_stack = 0
+ cubeA_touching_cubeB = self.check_contact(self.cubeA, self.cubeB)
+ if not grasping_cubeA and r_lift > 0 and cubeA_touching_cubeB:
+ r_stack = 2.0
+
+ return r_reach, r_lift, r_stack
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = TableArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # initialize objects of interest
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "1 1",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ redwood = CustomMaterial(
+ texture="WoodRed",
+ tex_name="redwood",
+ mat_name="redwood_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ greenwood = CustomMaterial(
+ texture="WoodGreen",
+ tex_name="greenwood",
+ mat_name="greenwood_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.cubeA = BoxObject(
+ name="cubeA",
+ size_min=[0.02, 0.02, 0.02],
+ size_max=[0.02, 0.02, 0.02],
+ rgba=[1, 0, 0, 1],
+ material=redwood,
+ )
+ self.cubeB = BoxObject(
+ name="cubeB",
+ size_min=[0.025, 0.025, 0.025],
+ size_max=[0.025, 0.025, 0.025],
+ rgba=[0, 1, 0, 1],
+ material=greenwood,
+ )
+ cubes = [self.cubeA, self.cubeB]
+ # Create placement initializer
+ if self.placement_initializer is not None:
+ self.placement_initializer.reset()
+ self.placement_initializer.add_objects(cubes)
+ else:
+ self.placement_initializer = UniformRandomSampler(
+ name="ObjectSampler",
+ mujoco_objects=cubes,
+ x_range=[-0.08, 0.08],
+ y_range=[-0.08, 0.08],
+ rotation=None,
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=True,
+ reference_pos=self.table_offset,
+ z_offset=0.01,
+ )
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=cubes,
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.cubeA_body_id = self.sim.model.body_name2id(self.cubeA.root_body)
+ self.cubeB_body_id = self.sim.model.body_name2id(self.cubeB.root_body)
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ # position and rotation of the first cube
+ @sensor(modality=modality)
+ def cubeA_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.cubeA_body_id])
+
+ @sensor(modality=modality)
+ def cubeA_quat(obs_cache):
+ return convert_quat(np.array(self.sim.data.body_xquat[self.cubeA_body_id]), to="xyzw")
+
+ @sensor(modality=modality)
+ def cubeB_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.cubeB_body_id])
+
+ @sensor(modality=modality)
+ def cubeB_quat(obs_cache):
+ return convert_quat(np.array(self.sim.data.body_xquat[self.cubeB_body_id]), to="xyzw")
+
+ @sensor(modality=modality)
+ def gripper_to_cubeA(obs_cache):
+ return (
+ obs_cache["cubeA_pos"] - obs_cache[f"{pf}eef_pos"]
+ if "cubeA_pos" in obs_cache and f"{pf}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def gripper_to_cubeB(obs_cache):
+ return (
+ obs_cache["cubeB_pos"] - obs_cache[f"{pf}eef_pos"]
+ if "cubeB_pos" in obs_cache and f"{pf}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def cubeA_to_cubeB(obs_cache):
+ return (
+ obs_cache["cubeB_pos"] - obs_cache["cubeA_pos"]
+ if "cubeA_pos" in obs_cache and "cubeB_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ sensors = [cubeA_pos, cubeA_quat, cubeB_pos, cubeB_quat, gripper_to_cubeA, gripper_to_cubeB, cubeA_to_cubeB]
+ names = [s.__name__ for s in sensors]
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _check_success(self):
+ """
+ Check if blocks are stacked correctly.
+
+ Returns:
+ bool: True if blocks are correctly stacked
+ """
+ _, _, r_stack = self.staged_rewards()
+ return r_stack > 0
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the cube.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to the cube
+ if vis_settings["grippers"]:
+ self._visualize_gripper_to_target(gripper=self.robots[0].gripper, target=self.cubeA)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/tool_hang.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/tool_hang.py
new file mode 100644
index 0000000000000000000000000000000000000000..df5d63806a1e612a0e359b3466b5483f39706854
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/tool_hang.py
@@ -0,0 +1,736 @@
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import HookFrame, RatchetingWrenchObject, StandWithMount
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.mjcf_utils import CustomMaterial
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import SequentialCompositeSampler, UniformRandomSampler
+from robosuite.utils.sim_utils import check_contact
+
+
+class ToolHang(SingleArmEnv):
+ """
+ This class corresponds to the tool hang task for a single robot arm.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.8))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ reward = 0.0
+
+ # sparse completion reward
+ if self._check_success():
+ reward = 1.0
+
+ # Scale reward if requested
+ if self.reward_scale is not None:
+ reward *= self.reward_scale
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+
+ Made some aspects easier than the real world task:
+ - increase base thickness for stand
+ - increase mount width to 1.2 cm
+ - add hole visualization
+ - reduce hook height on stand a little
+ - reduce tool ends height a little
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = TableArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # Modify default agentview camera
+ mujoco_arena.set_camera(
+ camera_name="agentview",
+ pos=[0.4837275266036987, 0.2505579098815722, 1.2639379055124524],
+ quat=[0.39713290333747864, 0.27807527780532837, 0.5016612410545349, 0.7164464592933655],
+ )
+
+ # Add sideview
+ mujoco_arena.set_camera(
+ camera_name="sideview",
+ pos=[0.4837275266036987, 0.2505579098815722, 1.2139379055124524],
+ quat=[0.39713290333747864, 0.27807527780532837, 0.5016612410545349, 0.7164464592933655],
+ )
+
+ # Create stand, frame, and tool
+ self.stand_args = dict(
+ name="stand",
+ size=(
+ (12.0 / 100.0),
+ (14.0 / 100.0),
+ (16.0 / 100.0),
+ ), # 14 cm x 12 cm base, with 16 cm height (in real world we cut the 32 cm height stand in half as well)
+ mount_location=(0.0, (4.5 / 100.0)), # 2.5 cm from right edge, so 4.5 cm to the right
+ mount_width=(1.2 / 100.0), # 1.2 cm thickness for rod cavity
+ wall_thickness=(0.1 / 100.0), # about 0.1-0.2 cm thickness for walls
+ base_thickness=(1 / 100.0), # increased thickness to 1 cm (different from real)
+ initialize_on_side=False,
+ add_hole_vis=True,
+ density=50000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.998, 0.998, 0.001),
+ )
+ self.stand = StandWithMount(**self.stand_args)
+
+ self.frame_args = dict(
+ name="frame",
+ frame_length=(9.5 / 100.0), # 9.5 cm wide
+ frame_height=(18.0 / 100.0), # 18 cm tall (in real world we cut the physical 36 cm rod in half as well)
+ frame_thickness=(0.75 / 100.0), # 0.75 cm thick
+ hook_height=(1.2 / 100.0), # lowered to 1.2 cm tall (instead of 1.7 cm in real world)
+ grip_location=((9.0 - 3.0) / 100.0)
+ - (0.75 / 200.0), # move up by half height of frame minus half height of grip minus half thickness
+ grip_size=((2.54 / 200.0), (6.35 / 200.0)), # 6.35 cm length, 2.54 cm thick
+ tip_size=(
+ (2.54 / 200.0),
+ (0.2 / 200.0),
+ (0.65 / 200.0),
+ (1.905 / 100.0),
+ ), # 1-inch cylinder, 0.65 inch solder tip
+ density=500.0,
+ solref=(0.02, 1.0),
+ solimp=(0.998, 0.998, 0.001),
+ )
+ self.frame = HookFrame(**self.frame_args)
+
+ self.real_tool_args = dict(
+ name="tool",
+ handle_size=(
+ (16.5 / 200.0),
+ (1.75 / 200.0),
+ (0.32 / 200.0),
+ ), # 16.5 cm length, 1.75 cm width, 0.32 cm thick (1.5 cm with foam)
+ outer_radius_1=(3.5 / 200.0), # larger hole 3.5 cm outer diameter
+ inner_radius_1=(2.1 / 200.0), # reduced larger hole 2.1 cm inner diameter (from real world 2.3 cm)
+ height_1=(0.7 / 200.0), # 0.7 cm height
+ outer_radius_2=(3.0 / 200.0), # smaller hole 3 cm outer diameter
+ inner_radius_2=(2.0 / 200.0), # smaller hole 2 cm outer diameter
+ height_2=(0.7 / 200.0), # 0.7 cm height
+ ngeoms=8,
+ grip_size=((3 / 200.0), (8.0 / 200.0)), # 8 cm length, 3 cm thick
+ density=2000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.998, 0.998, 0.001),
+ friction=(0.95, 0.3, 0.1),
+ )
+
+ self.tool_args = self.real_tool_args
+ self.tool = RatchetingWrenchObject(**self.tool_args)
+
+ # Create placement initializer
+ self._get_placement_initializer()
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=[self.stand, self.frame, self.tool],
+ )
+
+ def _get_placement_initializer(self):
+ """
+ Helper function for defining placement initializer and object sampling bounds
+ """
+ # Create placement initializer
+ self.placement_initializer = SequentialCompositeSampler(name="ObjectSampler")
+
+ # Pre-define settings for each object's placement
+ objects = [self.stand, self.frame, self.tool]
+ x_centers = [-self.table_full_size[0] * 0.1, -self.table_full_size[0] * 0.05, self.table_full_size[0] * 0.05]
+ y_centers = [0.0, -self.table_full_size[1] * 0.3, -self.table_full_size[1] * 0.25]
+ x_tols = [0.0, 0.02, 0.02]
+ y_tols = [0.0, 0.02, 0.02]
+ rot_centers = [0, (-np.pi / 2) + (np.pi / 6), (-np.pi / 2) - (np.pi / 9.0)]
+ rot_tols = [0.0, np.pi / 18, np.pi / 18.0]
+ rot_axes = ["z", "y", "z"]
+ z_offsets = [
+ 0.001,
+ (self.frame_args["frame_thickness"] - self.frame_args["frame_height"]) / 2.0
+ + 0.001
+ + (self.stand_args["base_thickness"] / 2.0)
+ + (self.frame_args["grip_size"][1]),
+ 0.001,
+ ]
+ if ("tip_size" in self.frame_args) and (self.frame_args["tip_size"] is not None):
+ z_offsets[1] -= self.frame_args["tip_size"][0] + 2.0 * self.frame_args["tip_size"][3]
+ for obj, x, y, x_tol, y_tol, r, r_tol, r_axis, z_offset in zip(
+ objects, x_centers, y_centers, x_tols, y_tols, rot_centers, rot_tols, rot_axes, z_offsets
+ ):
+ # Create sampler for this object and add it to the sequential sampler
+ self.placement_initializer.append_sampler(
+ sampler=UniformRandomSampler(
+ name=f"{obj.name}ObjectSampler",
+ mujoco_objects=obj,
+ x_range=[x - x_tol, x + x_tol],
+ y_range=[y - y_tol, y + y_tol],
+ rotation=[r - r_tol, r + r_tol],
+ rotation_axis=r_axis,
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=False,
+ reference_pos=self.table_offset,
+ z_offset=z_offset,
+ )
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.obj_body_id = dict(
+ stand=self.sim.model.body_name2id(self.stand.root_body),
+ frame=self.sim.model.body_name2id(self.frame.root_body),
+ tool=self.sim.model.body_name2id(self.tool.root_body),
+ )
+
+ # Important sites:
+ # tool_hole1_center - for checking hanging
+ # frame_hang_site, frame_mount_site, frame_intersection_site - for orienting the hook, and checking hanging
+ # stand_mount_site - for checking that stand base is upright
+ self.obj_site_id = dict(
+ tool_hole1_center=self.sim.model.site_name2id("tool_hole1_center"), # center of one end of wrench
+ # tool_hole2_center=self.sim.model.site_name2id("tool_hole2_center"), # center of other end of wrench
+ frame_hang_site=self.sim.model.site_name2id("frame_hang_site"), # end of frame where hanging takes place
+ frame_mount_site=self.sim.model.site_name2id(
+ "frame_mount_site"
+ ), # bottom of frame that needs to be inserted into base
+ frame_intersection_site=self.sim.model.site_name2id("frame_intersection_site"), # corner of frame
+ stand_mount_site=self.sim.model.site_name2id(
+ "stand_mount_site"
+ ), # where frame needs to be inserted into stand
+ )
+ if ("tip_size" in self.frame_args) and (self.frame_args["tip_size"] is not None):
+ self.obj_site_id["frame_tip_site"] = self.sim.model.site_name2id("frame_tip_site") # tip site for insertion
+
+ # Important geoms:
+ # stand_base - for checking that stand base is upright
+ # stand wall geoms - for checking rod insertion into stand
+ # tool hole geoms - for checking insertion
+ self.obj_geom_id = dict(
+ stand_base=self.sim.model.geom_name2id("stand_base"), # bottom of stand
+ )
+ for i in range(4):
+ self.obj_geom_id["stand_wall_{}".format(i)] = self.sim.model.geom_name2id("stand_wall{}".format(i))
+ for i in range(self.tool_args["ngeoms"]):
+ self.obj_geom_id["tool_hole1_hc_{}".format(i)] = self.sim.model.geom_name2id("tool_hole1_hc_{}".format(i))
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ # for conversion to relative gripper frame
+ @sensor(modality=modality)
+ def world_pose_in_gripper(obs_cache):
+ return (
+ T.pose_inv(T.pose2mat((obs_cache[f"{pf}eef_pos"], obs_cache[f"{pf}eef_quat"])))
+ if f"{pf}eef_pos" in obs_cache and f"{pf}eef_quat" in obs_cache
+ else np.eye(4)
+ )
+
+ sensors = [world_pose_in_gripper]
+ names = ["world_pose_in_gripper"]
+ actives = [False]
+
+ # Add absolute and relative pose for each object
+ obj_names = ["base", "frame", "tool"]
+ query_names = ["stand_base", "frame_intersection_site", "tool"]
+ query_types = ["geom", "site", "body"]
+ for i in range(len(obj_names)):
+ obj_sensors, obj_sensor_names = self._create_obj_sensors(
+ obj_name=obj_names[i], modality=modality, query_name=query_names[i], query_type=query_types[i]
+ )
+ sensors += obj_sensors
+ names += obj_sensor_names
+ actives += [True] * len(obj_sensors)
+
+ # Key boolean checks
+ @sensor(modality=modality)
+ def frame_is_assembled(obs_cache):
+ return [float(self._check_frame_assembled())]
+
+ @sensor(modality=modality)
+ def tool_on_frame(obs_cache):
+ return [float(self._check_tool_on_frame())]
+
+ sensors += [frame_is_assembled, tool_on_frame]
+ names += [frame_is_assembled.__name__, tool_on_frame.__name__]
+ actives += [True, True]
+
+ # Create observables
+ for name, s, active in zip(names, sensors, actives):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ active=active,
+ )
+
+ return observables
+
+ def _create_obj_sensors(self, obj_name, modality="object", query_name=None, query_type="body"):
+ """
+ Helper function to create sensors for a given object. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+
+ Args:
+ obj_name (str): Name of object to create sensors for (used for naming observations)
+ modality (str): Modality to assign to all sensors
+ query_name (str): Name to query mujoco for the pose attributes of this object - if None, use @obj_name
+ query_type (str): Either "body", "geom", or "site" - type of mujoco sensor that will be queried for pose
+
+ Returns:
+ 2-tuple:
+ sensors (list): Array of sensors for the given obj
+ names (list): array of corresponding observable names
+ """
+ if query_name is None:
+ query_name = obj_name
+
+ assert query_type in ["body", "geom", "site"]
+ if query_type == "body":
+ id_lookup = self.obj_body_id
+ pos_lookup = self.sim.data.body_xpos
+ mat_lookup = self.sim.data.body_xmat
+ elif query_type == "geom":
+ id_lookup = self.obj_geom_id
+ pos_lookup = self.sim.data.geom_xpos
+ mat_lookup = self.sim.data.geom_xmat
+ else:
+ id_lookup = self.obj_site_id
+ pos_lookup = self.sim.data.site_xpos
+ mat_lookup = self.sim.data.site_xmat
+
+ ### TODO: this was slightly modified from pick-place - do we want to move this into utils to share it? ###
+ pf = self.robots[0].robot_model.naming_prefix
+
+ @sensor(modality=modality)
+ def obj_pos(obs_cache):
+ return np.array(pos_lookup[id_lookup[query_name]])
+
+ @sensor(modality=modality)
+ def obj_quat(obs_cache):
+ return T.mat2quat(np.array(mat_lookup[id_lookup[query_name]]).reshape(3, 3))
+
+ @sensor(modality=modality)
+ def obj_to_eef_pos(obs_cache):
+ # Immediately return default value if cache is empty
+ if any(
+ [name not in obs_cache for name in [f"{obj_name}_pos", f"{obj_name}_quat", "world_pose_in_gripper"]]
+ ):
+ return np.zeros(3)
+ obj_pose = T.pose2mat((obs_cache[f"{obj_name}_pos"], obs_cache[f"{obj_name}_quat"]))
+ rel_pose = T.pose_in_A_to_pose_in_B(obj_pose, obs_cache["world_pose_in_gripper"])
+ rel_pos, rel_quat = T.mat2pose(rel_pose)
+ obs_cache[f"{obj_name}_to_{pf}eef_quat"] = rel_quat
+ return rel_pos
+
+ @sensor(modality=modality)
+ def obj_to_eef_quat(obs_cache):
+ return (
+ obs_cache[f"{obj_name}_to_{pf}eef_quat"] if f"{obj_name}_to_{pf}eef_quat" in obs_cache else np.zeros(4)
+ )
+
+ sensors = [obj_pos, obj_quat, obj_to_eef_pos, obj_to_eef_quat]
+ names = [f"{obj_name}_pos", f"{obj_name}_quat", f"{obj_name}_to_{pf}eef_pos", f"{obj_name}_to_{pf}eef_quat"]
+
+ return sensors, names
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to the cube.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to the cube
+ if vis_settings["grippers"]:
+ self._visualize_gripper_to_target(gripper=self.robots[0].gripper, target=self.tool)
+
+ def _check_success(self):
+ """
+ Check if tool is hung on frame correctly and frame is assembled coorectly as well.
+
+ Returns:
+ bool: True if tool is hung on frame correctly
+ """
+ return self._check_frame_assembled() and self._check_tool_on_frame()
+
+ def _check_frame_assembled(self):
+ """
+ Check if the frame has been assembled correctly. This checks the following things:
+ (1) the base is upright
+ (2) the end of the hook frame is close enough to the base
+ (3) the hook frame is between the walls of the base
+ """
+
+ # position of base
+ base_pos = self.sim.data.geom_xpos[self.obj_geom_id["stand_base"]]
+
+ # check (1): the base is upright. Just take the vector between two locations on the base shaft, and check
+ # that the angle to the z-axis is small, by computing the angle between that unit vector and
+ # the z-axis. Recall that for two unit vectors, the arccosine of the dot product gives the angle.
+ vec_along_base_shaft = self.sim.data.site_xpos[self.obj_site_id["stand_mount_site"]] - base_pos
+ vec_along_base_shaft = vec_along_base_shaft / np.linalg.norm(vec_along_base_shaft)
+ angle_to_z_axis = np.abs(np.arccos(vec_along_base_shaft[2]))
+ base_shaft_is_vertical = angle_to_z_axis < np.pi / 18.0 # less than 10 degrees
+
+ # check (2): the end of the hook frame is close enough to the base. Just check the distance
+ if "frame_tip_site" in self.obj_site_id:
+ bottom_hook_pos = self.sim.data.site_xpos[self.obj_site_id["frame_tip_site"]]
+ else:
+ bottom_hook_pos = self.sim.data.site_xpos[self.obj_site_id["frame_mount_site"]]
+ insertion_dist = np.linalg.norm(bottom_hook_pos - base_pos)
+ # insertion_tolerance = (self.frame_args["frame_thickness"] / 2.)
+ insertion_tolerance = 0.05 # NOTE: this was manually tuned
+ bottom_is_close_enough = insertion_dist < insertion_tolerance
+
+ # check (3): the hook frame is in between the walls of the base. Take the geom positions of opposing base walls
+ # and check that they are on opposite sides of the line defined by the hook frame.
+
+ # normalized vector that points along the frame hook
+ hook_endpoint = self.sim.data.site_xpos[self.obj_site_id["frame_mount_site"]]
+ frame_hook_vec = self.sim.data.site_xpos[self.obj_site_id["frame_intersection_site"]] - hook_endpoint
+ frame_hook_length = np.linalg.norm(frame_hook_vec)
+ frame_hook_vec = frame_hook_vec / frame_hook_length
+
+ # geom wall position vectors relative to base position
+ geom_positions = [
+ self.sim.data.geom_xpos[self.obj_geom_id["stand_wall_{}".format(i)]] - hook_endpoint for i in range(4)
+ ]
+
+ # take cross product of each point against the line, and then dot the result to see if
+ # the sign is positive or negative. If it is positive, then they are on the same side
+ # (visualize with right-hand-rule to see this)
+ rod_is_between_stand_walls = all(
+ [
+ np.dot(np.cross(geom_positions[0], frame_hook_vec), np.cross(geom_positions[2], frame_hook_vec)) < 0,
+ np.dot(np.cross(geom_positions[1], frame_hook_vec), np.cross(geom_positions[3], frame_hook_vec)) < 0,
+ ]
+ )
+
+ return base_shaft_is_vertical and (bottom_is_close_enough and rod_is_between_stand_walls)
+
+ def _check_tool_on_frame(self):
+ """
+ Check if the tool has been hung on the frame correctly. This checks the following things:
+ (1) the robot is not touching the tool (it is hanging on its own)
+ (2) the tool hole is making contact with the frame hook
+ (3) the tool hole is close to the line defined by the frame hook
+ (4) either end of the tool hole are on opposite sides of the frame hook
+ (5) the tool hole is inserted far enough into the frame hook
+ """
+
+ # check (1): robot is not touching the tool
+ robot_grasp_geoms = [
+ self.robots[0].gripper.important_geoms["left_fingerpad"],
+ self.robots[0].gripper.important_geoms["right_fingerpad"],
+ ]
+ robot_and_tool_contact = False
+ for g_group in robot_grasp_geoms:
+ if check_contact(self.sim, g_group, self.tool.contact_geoms):
+ robot_and_tool_contact = True
+ break
+
+ # check (2): the tool hole is making contact with the frame hook
+ all_tool_hole_geoms = ["tool_hole1_hc_{}".format(i) for i in range(self.tool_args["ngeoms"])]
+ frame_hook_geom = "frame_horizontal_frame"
+ frame_and_tool_hole_contact = check_contact(self.sim, all_tool_hole_geoms, frame_hook_geom)
+
+ # check (3): compute distance from tool hole center to the line defined by the frame hook
+
+ # normalized vector that points along the frame hook
+ hook_endpoint = self.sim.data.site_xpos[self.obj_site_id["frame_hang_site"]]
+ frame_hook_vec = self.sim.data.site_xpos[self.obj_site_id["frame_intersection_site"]] - hook_endpoint
+ frame_hook_length = np.linalg.norm(frame_hook_vec)
+ frame_hook_vec = frame_hook_vec / frame_hook_length
+
+ # compute orthogonal projection of tool hole point to get distance to frame hook line
+ # (see https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation)
+ tool_hole_center = self.sim.data.site_xpos[self.obj_site_id["tool_hole1_center"]]
+ tool_hole_vec = tool_hole_center - hook_endpoint
+ tool_hole_dot = np.dot(tool_hole_vec, frame_hook_vec)
+ tool_hole_proj = tool_hole_dot * frame_hook_vec
+ tool_hole_ortho_proj = tool_hole_vec - tool_hole_proj
+ dist_to_frame_hook_line = np.linalg.norm(tool_hole_ortho_proj)
+
+ # distance needs to be less than the difference between the inner tool hole radius and the half-length of the frame hook box geom
+ tool_hole_is_close_enough = dist_to_frame_hook_line < (
+ self.tool_args["inner_radius_1"] - (self.frame_args["frame_thickness"] / 2.0)
+ )
+
+ # check (4): take two opposite geoms around the tool hole, and check that they are on opposite sides of the frame hook line
+ # to guarantee that insertion has taken place
+ g2_id = self.tool_args["ngeoms"] // 2 # get geom opposite geom 0
+ g1_pos = self.sim.data.geom_xpos[self.obj_geom_id["tool_hole1_hc_0"]]
+ g2_pos = self.sim.data.geom_xpos[self.obj_geom_id["tool_hole1_hc_{}".format(g2_id)]]
+
+ # take cross product of each point against the line, and then dot the result to see if
+ # the sign is positive or negative. If it is positive, then they are on the same side
+ # (visualize with right-hand-rule to see this)
+ g1_vec = g1_pos - hook_endpoint
+ g2_vec = g2_pos - hook_endpoint
+ tool_is_between_hook = np.dot(np.cross(g1_vec, frame_hook_vec), np.cross(g2_vec, frame_hook_vec)) < 0
+
+ # check (5): check if tool insertion is far enough - check this by computing normalized distance of projection along frame hook line.
+ # We ensure that it's at least 5% inserted along the length of the frame hook.
+ normalized_dist_along_frame_hook_line = tool_hole_dot / frame_hook_length
+ tool_is_inserted_far_enough = (normalized_dist_along_frame_hook_line > 0.05) and (
+ normalized_dist_along_frame_hook_line < 1.0
+ )
+
+ return all(
+ [
+ (not robot_and_tool_contact),
+ frame_and_tool_hole_contact,
+ tool_hole_is_close_enough,
+ tool_is_between_hook,
+ tool_is_inserted_far_enough,
+ ]
+ )
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_env.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c99f8d68deaf5c83eaed5fe464a11a56a3caac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_env.py
@@ -0,0 +1,136 @@
+import numpy as np
+
+from robosuite.environments.manipulation.manipulation_env import ManipulationEnv
+from robosuite.utils.robot_utils import check_bimanual
+from robosuite.utils.transform_utils import mat2quat
+
+
+class TwoArmEnv(ManipulationEnv):
+ """
+ A manipulation environment intended for two robot arms.
+ """
+
+ def _check_robot_configuration(self, robots):
+ """
+ Sanity check to make sure the inputted robots and configuration is acceptable
+
+ Args:
+ robots (str or list of str): Robots to instantiate within this env
+ """
+ super()._check_robot_configuration(robots)
+ robots = robots if type(robots) == list or type(robots) == tuple else [robots]
+ # If default config is used, set env_configuration accordingly
+ if self.env_configuration == "default":
+ self.env_configuration = "bimanual" if check_bimanual(robots[0]) else "single-arm-opposed"
+
+ if self.env_configuration == "single-arm-opposed" or self.env_configuration == "single-arm-parallel":
+ # Specifically two robots should be inputted!
+ is_bimanual = False
+ if type(robots) is not list or len(robots) != 2:
+ raise ValueError(
+ "Error: Exactly two single-armed robots should be inputted " "for this task configuration!"
+ )
+ elif self.env_configuration == "bimanual":
+ is_bimanual = True
+ # Specifically one robot should be inputted!
+ if type(robots) is list and len(robots) != 1:
+ raise ValueError("Error: Exactly one bimanual robot should be inputted " "for this task configuration!")
+ else:
+ # This is an unknown env configuration, print error
+ raise ValueError(
+ "Error: Unknown environment configuration received. Only 'bimanual',"
+ "'single-arm-parallel', and 'single-arm-opposed' are supported. Got: {}".format(self.env_configuration)
+ )
+
+ # Lastly, check to make sure all inputted robot names are of their correct type (bimanual / not bimanual)
+ for robot in robots:
+ if check_bimanual(robot) != is_bimanual:
+ raise ValueError(
+ "Error: For {} configuration, expected bimanual check to return {}; "
+ "instead, got {}.".format(self.env_configuration, is_bimanual, check_bimanual(robot))
+ )
+
+ @property
+ def _eef0_xpos(self):
+ """
+ Grab the position of Robot 0's end effector.
+
+ Returns:
+ np.array: (x,y,z) position of EEF0
+ """
+ if self.env_configuration == "bimanual":
+ return np.array(self.sim.data.site_xpos[self.robots[0].eef_site_id["right"]])
+ else:
+ return np.array(self.sim.data.site_xpos[self.robots[0].eef_site_id])
+
+ @property
+ def _eef1_xpos(self):
+ """
+ Grab the position of Robot 1's end effector.
+
+ Returns:
+ np.array: (x,y,z) position of EEF1
+ """
+ if self.env_configuration == "bimanual":
+ return np.array(self.sim.data.site_xpos[self.robots[0].eef_site_id["left"]])
+ else:
+ return np.array(self.sim.data.site_xpos[self.robots[1].eef_site_id])
+
+ @property
+ def _eef0_xmat(self):
+ """
+ End Effector 0 orientation as a rotation matrix
+ Note that this draws the orientation from the "ee" site, NOT the gripper site, since the gripper
+ orientations are inconsistent!
+
+ Returns:
+ np.array: (3,3) orientation matrix for EEF0
+ """
+ pf = self.robots[0].gripper.naming_prefix
+
+ if self.env_configuration == "bimanual":
+ return np.array(self.sim.data.site_xmat[self.sim.model.site_name2id(pf + "right_grip_site")]).reshape(3, 3)
+
+ else:
+ return np.array(self.sim.data.site_xmat[self.sim.model.site_name2id(pf + "grip_site")]).reshape(3, 3)
+
+ @property
+ def _eef1_xmat(self):
+ """
+ End Effector 1 orientation as a rotation matrix
+ Note that this draws the orientation from the "ee" site, NOT the gripper site, since the gripper
+ orientations are inconsistent!
+
+ Returns:
+ np.array: (3,3) orientation matrix for EEF1
+ """
+ if self.env_configuration == "bimanual":
+ pf = self.robots[0].gripper.naming_prefix
+ return np.array(self.sim.data.site_xmat[self.sim.model.site_name2id(pf + "left_grip_site")]).reshape(3, 3)
+ else:
+ pf = self.robots[1].gripper.naming_prefix
+ return np.array(self.sim.data.site_xmat[self.sim.model.site_name2id(pf + "grip_site")]).reshape(3, 3)
+
+ @property
+ def _eef0_xquat(self):
+ """
+ End Effector 0 orientation as a (x,y,z,w) quaternion
+ Note that this draws the orientation from the "ee" site, NOT the gripper site, since the gripper
+ orientations are inconsistent!
+
+ Returns:
+ np.array: (x,y,z,w) quaternion for EEF0
+ """
+ return mat2quat(self._eef0_xmat)
+
+ @property
+ def _eef1_xquat(self):
+ """
+ End Effector 1 orientation as a (x,y,z,w) quaternion
+ Note that this draws the orientation from the "ee" site, NOT the gripper site, since the gripper
+ orientations are inconsistent!
+
+ Returns:
+ np.array: (x,y,z,w) quaternion for EEF1
+ """
+ return mat2quat(self._eef1_xmat)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_handover.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_handover.py
new file mode 100644
index 0000000000000000000000000000000000000000..db0d5f94d0b75e59c85128bdcc4d5209d5c84547
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_handover.py
@@ -0,0 +1,617 @@
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.two_arm_env import TwoArmEnv
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import HammerObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import UniformRandomSampler
+
+
+class TwoArmHandover(TwoArmEnv):
+ """
+ This class corresponds to the handover task for two robot arms.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be either 2 single single-arm robots or 1 bimanual robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment. Can be either:
+
+ :`'bimanual'`: Only applicable for bimanual robot setups. Sets up the (single) bimanual robot on the -x
+ side of the table
+ :`'single-arm-parallel'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots next to each other on the -x side of the table
+ :`'single-arm-opposed'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots opposed from each others on the opposite +/-y sides of the table.
+
+ Note that "default" corresponds to either "bimanual" if a bimanual robot is used or "single-arm-opposed" if two
+ single-arm robots are used.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ prehensile (bool): If true, handover object starts on the table. Else, the object starts in Arm0's gripper
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ ValueError: [Invalid number of robots specified]
+ ValueError: [Invalid env configuration]
+ ValueError: [Invalid robots for specified env configuration]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ prehensile=True,
+ table_full_size=(0.8, 1.2, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # Task settings
+ self.prehensile = prehensile
+
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_true_size = list(table_full_size)
+ self.table_true_size[1] *= 0.25 # true size will only be partially wide
+ self.table_friction = table_friction
+ self.table_offset = [0, self.table_full_size[1] * (-3 / 8), 0.8]
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+ self.height_threshold = 0.1 # threshold above the table surface which the hammer is considered lifted
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 2.0 is provided when only Arm 1 is gripping the handle and has the handle
+ lifted above a certain threshold
+
+ Un-normalized max-wise components if using reward shaping:
+
+ - Arm0 Reaching: (1) in [0, 0.25] proportional to the distance between Arm 0 and the handle
+ - Arm0 Grasping: (2) in {0, 0.5}, nonzero if Arm 0 is gripping the hammer (any part).
+ - Arm0 Lifting: (3) in {0, 1.0}, nonzero if Arm 0 lifts the handle from the table past a certain threshold
+ - Arm0 Hovering: (4) in {0, [1.0, 1.25]}, nonzero only if Arm0 is actively lifting the hammer, and is
+ proportional to the distance between the handle and Arm 1
+ conditioned on the handle being lifted from the table and being grasped by Arm 0
+ - Mutual Grasping: (5) in {0, 1.5}, nonzero if both Arm 0 and Arm 1 are gripping the hammer (Arm 1 must be
+ gripping the handle) while lifted above the table
+ - Handover: (6) in {0, 2.0}, nonzero when only Arm 1 is gripping the handle and has the handle
+ lifted above the table
+
+ Note that the final reward is normalized and scaled by reward_scale / 2.0 as
+ well so that the max score is equal to reward_scale
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ # Initialize reward
+ reward = 0
+
+ # use a shaping reward if specified
+ if self.reward_shaping:
+ # Grab relevant parameters
+ arm0_grasp_any, arm1_grasp_handle, hammer_height, table_height = self._get_task_info()
+ # First, we'll consider the cases if the hammer is lifted above the threshold (step 3 - 6)
+ if hammer_height - table_height > self.height_threshold:
+ # Split cases depending on whether arm1 is currently grasping the handle or not
+ if arm1_grasp_handle:
+ # Check if arm0 is grasping
+ if arm0_grasp_any:
+ # This is step 5
+ reward = 1.5
+ else:
+ # This is step 6 (completed task!)
+ reward = 2.0
+ # This is the case where only arm0 is grasping (step 2-3)
+ else:
+ reward = 1.0
+ # Add in up to 0.25 based on distance between handle and arm1
+ dist = np.linalg.norm(self._gripper_1_to_handle)
+ reaching_reward = 0.25 * (1 - np.tanh(1.0 * dist))
+ reward += reaching_reward
+ # Else, the hammer is still on the ground ):
+ else:
+ # Split cases depending on whether arm0 is currently grasping the handle or not
+ if arm0_grasp_any:
+ # This is step 2
+ reward = 0.5
+ else:
+ # This is step 1, we want to encourage arm0 to reach for the handle
+ dist = np.linalg.norm(self._gripper_0_to_handle)
+ reaching_reward = 0.25 * (1 - np.tanh(1.0 * dist))
+ reward = reaching_reward
+
+ # Else this is the sparse reward setting
+ else:
+ # Provide reward if only Arm 1 is grasping the hammer and the handle lifted above the pre-defined threshold
+ if self._check_success():
+ reward = 2.0
+
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 2.0
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose(s) accordingly
+ if self.env_configuration == "bimanual":
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+ else:
+ if self.env_configuration == "single-arm-opposed":
+ # Set up robots facing towards each other by rotating them from their default position
+ for robot, rotation, offset in zip(self.robots, (np.pi / 2, -np.pi / 2), (-0.25, 0.25)):
+ xpos = robot.robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ rot = np.array((0, 0, rotation))
+ xpos = T.euler2mat(rot) @ np.array(xpos)
+ xpos += np.array((0, offset, 0))
+ robot.robot_model.set_base_xpos(xpos)
+ robot.robot_model.set_base_ori(rot)
+ else: # "single-arm-parallel" configuration setting
+ # Set up robots parallel to each other but offset from the center
+ for robot, offset in zip(self.robots, (-0.6, 0.6)):
+ xpos = robot.robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ xpos = np.array(xpos) + np.array((0, offset, 0))
+ robot.robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = TableArena(
+ table_full_size=self.table_true_size, table_friction=self.table_friction, table_offset=self.table_offset
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # Modify default agentview camera
+ mujoco_arena.set_camera(
+ camera_name="agentview",
+ pos=[0.8894354364730311, -3.481824231498976e-08, 1.7383813133506494],
+ quat=[0.6530981063842773, 0.2710406184196472, 0.27104079723358154, 0.6530979871749878],
+ )
+
+ # initialize objects of interest
+ self.hammer = HammerObject(name="hammer")
+
+ # Create placement initializer
+ if self.placement_initializer is not None:
+ self.placement_initializer.reset()
+ self.placement_initializer.add_objects(self.hammer)
+ else:
+ # Set rotation about y-axis if hammer starts on table else rotate about z if it starts in gripper
+ rotation_axis = "y" if self.prehensile else "z"
+ self.placement_initializer = UniformRandomSampler(
+ name="ObjectSampler",
+ mujoco_objects=self.hammer,
+ x_range=[-0.1, 0.1],
+ y_range=[-0.05, 0.05],
+ rotation=None,
+ rotation_axis=rotation_axis,
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=True,
+ reference_pos=self.table_offset,
+ )
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=self.hammer,
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Hammer object references from this env
+ self.hammer_body_id = self.sim.model.body_name2id(self.hammer.root_body)
+ self.hammer_handle_geom_id = self.sim.model.geom_name2id(self.hammer.handle_geoms[0])
+
+ # General env references
+ self.table_top_id = self.sim.model.site_name2id("table_top")
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ if self.env_configuration == "bimanual":
+ pf0 = self.robots[0].robot_model.naming_prefix + "right_"
+ pf1 = self.robots[0].robot_model.naming_prefix + "left_"
+ else:
+ pf0 = self.robots[0].robot_model.naming_prefix
+ pf1 = self.robots[1].robot_model.naming_prefix
+ modality = "object"
+
+ # position and rotation of hammer
+ @sensor(modality=modality)
+ def hammer_pos(obs_cache):
+ return np.array(self._hammer_pos)
+
+ @sensor(modality=modality)
+ def hammer_quat(obs_cache):
+ return np.array(self._hammer_quat)
+
+ @sensor(modality=modality)
+ def handle_xpos(obs_cache):
+ return np.array(self._handle_xpos)
+
+ @sensor(modality=modality)
+ def gripper0_to_handle(obs_cache):
+ return (
+ obs_cache["handle_xpos"] - obs_cache[f"{pf0}eef_pos"]
+ if "handle_xpos" in obs_cache and f"{pf0}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def gripper1_to_handle(obs_cache):
+ return (
+ obs_cache["handle_xpos"] - obs_cache[f"{pf1}eef_pos"]
+ if "handle_xpos" in obs_cache and f"{pf1}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ sensors = [hammer_pos, hammer_quat, handle_xpos, gripper0_to_handle, gripper1_to_handle]
+ names = [s.__name__ for s in sensors]
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ # If prehensile, set the object normally
+ if self.prehensile:
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+ # Else, set the object in the hand of the robot and loop a few steps to guarantee the robot is grasping
+ # the object initially
+ else:
+ eef_rot_quat = T.mat2quat(T.euler2mat([np.pi - T.mat2euler(self._eef0_xmat)[2], 0, 0]))
+ obj_quat = T.quat_multiply(obj_quat, eef_rot_quat)
+ for j in range(100):
+ # Set object in hand
+ self.sim.data.set_joint_qpos(
+ obj.joints[0], np.concatenate([self._eef0_xpos, np.array(obj_quat)])
+ )
+ # Close gripper (action = 1) and prevent arm from moving
+ if self.env_configuration == "bimanual":
+ # Execute no-op action with gravity compensation
+ torques = np.concatenate(
+ [
+ self.robots[0].controller["right"].torque_compensation,
+ self.robots[0].controller["left"].torque_compensation,
+ ]
+ )
+ self.sim.data.ctrl[self.robots[0]._ref_joint_actuator_indexes] = torques
+ # Execute gripper action
+ self.robots[0].grip_action(gripper=self.robots[0].gripper["right"], gripper_action=[1])
+ else:
+ # Execute no-op action with gravity compensation
+ self.sim.data.ctrl[self.robots[0]._ref_joint_actuator_indexes] = self.robots[
+ 0
+ ].controller.torque_compensation
+ self.sim.data.ctrl[self.robots[1]._ref_joint_actuator_indexes] = self.robots[
+ 1
+ ].controller.torque_compensation
+ # Execute gripper action
+ self.robots[0].grip_action(gripper=self.robots[0].gripper, gripper_action=[1])
+ # Take forward step
+ self.sim.step()
+
+ def _get_task_info(self):
+ """
+ Helper function that grabs the current relevant locations of objects of interest within the environment
+
+ Returns:
+ 4-tuple:
+
+ - (bool) True if Arm0 is grasping any part of the hammer
+ - (bool) True if Arm1 is grasping the hammer handle
+ - (float) Height of the hammer body
+ - (float) Height of the table surface
+ """
+ # Get height of hammer and table and define height threshold
+ hammer_angle_offset = (self.hammer.handle_length / 2 + 2 * self.hammer.head_halfsize) * np.sin(
+ self._hammer_angle
+ )
+ hammer_height = (
+ self.sim.data.geom_xpos[self.hammer_handle_geom_id][2] - self.hammer.top_offset[2] - hammer_angle_offset
+ )
+ table_height = self.sim.data.site_xpos[self.table_top_id][2]
+
+ # Check if any Arm's gripper is grasping the hammer handle
+ (g0, g1) = (
+ (self.robots[0].gripper["right"], self.robots[0].gripper["left"])
+ if self.env_configuration == "bimanual"
+ else (self.robots[0].gripper, self.robots[1].gripper)
+ )
+ arm0_grasp_any = self._check_grasp(gripper=g0, object_geoms=self.hammer)
+ arm1_grasp_handle = self._check_grasp(gripper=g1, object_geoms=self.hammer.handle_geoms)
+
+ # Return all relevant values
+ return arm0_grasp_any, arm1_grasp_handle, hammer_height, table_height
+
+ def _check_success(self):
+ """
+ Check if hammer is successfully handed off
+
+ Returns:
+ bool: True if handover has been completed
+ """
+ # Grab relevant params
+ arm0_grasp_any, arm1_grasp_handle, hammer_height, table_height = self._get_task_info()
+ return (
+ True
+ if arm1_grasp_handle and not arm0_grasp_any and hammer_height - table_height > self.height_threshold
+ else False
+ )
+
+ @property
+ def _handle_xpos(self):
+ """
+ Grab the position of the hammer handle.
+
+ Returns:
+ np.array: (x,y,z) position of handle
+ """
+ return self.sim.data.geom_xpos[self.hammer_handle_geom_id]
+
+ @property
+ def _hammer_pos(self):
+ """
+ Grab the position of the hammer body.
+
+ Returns:
+ np.array: (x,y,z) position of body
+ """
+ return np.array(self.sim.data.body_xpos[self.hammer_body_id])
+
+ @property
+ def _hammer_quat(self):
+ """
+ Grab the orientation of the hammer body.
+
+ Returns:
+ np.array: (x,y,z,w) quaternion of the hammer body
+ """
+ return T.convert_quat(self.sim.data.body_xquat[self.hammer_body_id], to="xyzw")
+
+ @property
+ def _hammer_angle(self):
+ """
+ Calculate the angle of hammer with the ground, relative to it resting horizontally
+
+ Returns:
+ float: angle in radians
+ """
+ mat = T.quat2mat(self._hammer_quat)
+ z_unit = [0, 0, 1]
+ z_rotated = np.matmul(mat, z_unit)
+ return np.pi / 2 - np.arccos(np.dot(z_unit, z_rotated))
+
+ @property
+ def _gripper_0_to_handle(self):
+ """
+ Calculate vector from the left gripper to the hammer handle.
+
+ Returns:
+ np.array: (dx,dy,dz) distance vector between handle and EEF0
+ """
+ return self._handle_xpos - self._eef0_xpos
+
+ @property
+ def _gripper_1_to_handle(self):
+ """
+ Calculate vector from the right gripper to the hammer handle.
+
+ Returns:
+ np.array: (dx,dy,dz) distance vector between handle and EEF1
+ """
+ return self._handle_xpos - self._eef1_xpos
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_lift.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_lift.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3acf1599aaab948e7e10a201d2a24e6082189d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_lift.py
@@ -0,0 +1,545 @@
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.two_arm_env import TwoArmEnv
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import PotWithHandlesObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import UniformRandomSampler
+
+
+class TwoArmLift(TwoArmEnv):
+ """
+ This class corresponds to the lifting task for two robot arms.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be either 2 single single-arm robots or 1 bimanual robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment. Can be either:
+
+ :`'bimanual'`: Only applicable for bimanual robot setups. Sets up the (single) bimanual robot on the -x
+ side of the table
+ :`'single-arm-parallel'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots next to each other on the -x side of the table
+ :`'single-arm-opposed'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots opposed from each others on the opposite +/-y sides of the table.
+
+ Note that "default" corresponds to either "bimanual" if a bimanual robot is used or "single-arm-opposed" if two
+ single-arm robots are used.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ table_full_size (3-tuple): x, y, and z dimensions of the table.
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ the table.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ placement_initializer (ObjectPositionSampler): if provided, will
+ be used to place objects on every reset, else a UniformRandomSampler
+ is used by default.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ ValueError: [Invalid number of robots specified]
+ ValueError: [Invalid env configuration]
+ ValueError: [Invalid robots for specified env configuration]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ placement_initializer=None,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # settings for table top
+ self.table_full_size = table_full_size
+ self.table_friction = table_friction
+ self.table_offset = np.array((0, 0, 0.8))
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # object placement initializer
+ self.placement_initializer = placement_initializer
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 3.0 is provided if the pot is lifted and is parallel within 30 deg to the table
+
+ Un-normalized summed components if using reward shaping:
+
+ - Reaching: in [0, 0.5], per-arm component that is proportional to the distance between each arm and its
+ respective pot handle, and exactly 0.5 when grasping the handle
+ - Note that the agent only gets the lifting reward when flipping no more than 30 degrees.
+ - Grasping: in {0, 0.25}, binary per-arm component awarded if the gripper is grasping its correct handle
+ - Lifting: in [0, 1.5], proportional to the pot's height above the table, and capped at a certain threshold
+
+ Note that the final reward is normalized and scaled by reward_scale / 3.0 as
+ well so that the max score is equal to reward_scale
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ reward = 0
+
+ # check if the pot is tilted more than 30 degrees
+ mat = T.quat2mat(self._pot_quat)
+ z_unit = [0, 0, 1]
+ z_rotated = np.matmul(mat, z_unit)
+ cos_z = np.dot(z_unit, z_rotated)
+ cos_30 = np.cos(np.pi / 6)
+ direction_coef = 1 if cos_z >= cos_30 else 0
+
+ # check for goal completion: cube is higher than the table top above a margin
+ if self._check_success():
+ reward = 3.0 * direction_coef
+
+ # use a shaping reward
+ elif self.reward_shaping:
+ # lifting reward
+ pot_bottom_height = self.sim.data.site_xpos[self.pot_center_id][2] - self.pot.top_offset[2]
+ table_height = self.sim.data.site_xpos[self.table_top_id][2]
+ elevation = pot_bottom_height - table_height
+ r_lift = min(max(elevation - 0.05, 0), 0.15)
+ reward += 10.0 * direction_coef * r_lift
+
+ _gripper0_to_handle0 = self._gripper0_to_handle0
+ _gripper1_to_handle1 = self._gripper1_to_handle1
+
+ # gh stands for gripper-handle
+ # When grippers are far away, tell them to be closer
+
+ # Get contacts
+ (g0, g1) = (
+ (self.robots[0].gripper["right"], self.robots[0].gripper["left"])
+ if self.env_configuration == "bimanual"
+ else (self.robots[0].gripper, self.robots[1].gripper)
+ )
+
+ _g0h_dist = np.linalg.norm(_gripper0_to_handle0)
+ _g1h_dist = np.linalg.norm(_gripper1_to_handle1)
+
+ # Grasping reward
+ if self._check_grasp(gripper=g0, object_geoms=self.pot.handle0_geoms):
+ reward += 0.25
+ # Reaching reward
+ reward += 0.5 * (1 - np.tanh(10.0 * _g0h_dist))
+
+ # Grasping reward
+ if self._check_grasp(gripper=g1, object_geoms=self.pot.handle1_geoms):
+ reward += 0.25
+ # Reaching reward
+ reward += 0.5 * (1 - np.tanh(10.0 * _g1h_dist))
+
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 3.0
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose(s) accordingly
+ if self.env_configuration == "bimanual":
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+ else:
+ if self.env_configuration == "single-arm-opposed":
+ # Set up robots facing towards each other by rotating them from their default position
+ for robot, rotation in zip(self.robots, (np.pi / 2, -np.pi / 2)):
+ xpos = robot.robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ rot = np.array((0, 0, rotation))
+ xpos = T.euler2mat(rot) @ np.array(xpos)
+ robot.robot_model.set_base_xpos(xpos)
+ robot.robot_model.set_base_ori(rot)
+ else: # "single-arm-parallel" configuration setting
+ # Set up robots parallel to each other but offset from the center
+ for robot, offset in zip(self.robots, (-0.25, 0.25)):
+ xpos = robot.robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ xpos = np.array(xpos) + np.array((0, offset, 0))
+ robot.robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = TableArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # initialize objects of interest
+ self.pot = PotWithHandlesObject(name="pot")
+
+ # Create placement initializer
+ if self.placement_initializer is not None:
+ self.placement_initializer.reset()
+ self.placement_initializer.add_objects(self.pot)
+ else:
+ self.placement_initializer = UniformRandomSampler(
+ name="ObjectSampler",
+ mujoco_objects=self.pot,
+ x_range=[-0.03, 0.03],
+ y_range=[-0.03, 0.03],
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=True,
+ reference_pos=self.table_offset,
+ rotation=(np.pi + -np.pi / 3, np.pi + np.pi / 3),
+ )
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=self.pot,
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.pot_body_id = self.sim.model.body_name2id(self.pot.root_body)
+ self.handle0_site_id = self.sim.model.site_name2id(self.pot.important_sites["handle0"])
+ self.handle1_site_id = self.sim.model.site_name2id(self.pot.important_sites["handle1"])
+ self.table_top_id = self.sim.model.site_name2id("table_top")
+ self.pot_center_id = self.sim.model.site_name2id(self.pot.important_sites["center"])
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ if self.env_configuration == "bimanual":
+ pf0 = self.robots[0].robot_model.naming_prefix + "right_"
+ pf1 = self.robots[0].robot_model.naming_prefix + "left_"
+ else:
+ pf0 = self.robots[0].robot_model.naming_prefix
+ pf1 = self.robots[1].robot_model.naming_prefix
+ modality = "object"
+
+ # position and rotation of object
+
+ @sensor(modality=modality)
+ def pot_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.pot_body_id])
+
+ @sensor(modality=modality)
+ def pot_quat(obs_cache):
+ return T.convert_quat(self.sim.data.body_xquat[self.pot_body_id], to="xyzw")
+
+ @sensor(modality=modality)
+ def handle0_xpos(obs_cache):
+ return np.array(self._handle0_xpos)
+
+ @sensor(modality=modality)
+ def handle1_xpos(obs_cache):
+ return np.array(self._handle1_xpos)
+
+ @sensor(modality=modality)
+ def gripper0_to_handle0(obs_cache):
+ return (
+ obs_cache["handle0_xpos"] - obs_cache[f"{pf0}eef_pos"]
+ if "handle0_xpos" in obs_cache and f"{pf0}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def gripper1_to_handle1(obs_cache):
+ return (
+ obs_cache["handle1_xpos"] - obs_cache[f"{pf1}eef_pos"]
+ if "handle1_xpos" in obs_cache and f"{pf1}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ sensors = [pot_pos, pot_quat, handle0_xpos, handle1_xpos, gripper0_to_handle0, gripper1_to_handle1]
+ names = [s.__name__ for s in sensors]
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualize gripper site proportional to the distance to each handle.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "grippers" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+
+ # Color the gripper visualization site according to its distance to each handle
+ if vis_settings["grippers"]:
+ handles = [self.pot.important_sites[f"handle{i}"] for i in range(2)]
+ grippers = (
+ [self.robots[0].gripper[arm] for arm in self.robots[0].arms]
+ if self.env_configuration == "bimanual"
+ else [robot.gripper for robot in self.robots]
+ )
+ for gripper, handle in zip(grippers, handles):
+ self._visualize_gripper_to_target(gripper=gripper, target=handle, target_type="site")
+
+ def _check_success(self):
+ """
+ Check if pot is successfully lifted
+
+ Returns:
+ bool: True if pot is lifted
+ """
+ pot_bottom_height = self.sim.data.site_xpos[self.pot_center_id][2] - self.pot.top_offset[2]
+ table_height = self.sim.data.site_xpos[self.table_top_id][2]
+
+ # cube is higher than the table top above a margin
+ return pot_bottom_height > table_height + 0.10
+
+ @property
+ def _handle0_xpos(self):
+ """
+ Grab the position of the left (blue) hammer handle.
+
+ Returns:
+ np.array: (x,y,z) position of handle
+ """
+ return self.sim.data.site_xpos[self.handle0_site_id]
+
+ @property
+ def _handle1_xpos(self):
+ """
+ Grab the position of the right (green) hammer handle.
+
+ Returns:
+ np.array: (x,y,z) position of handle
+ """
+ return self.sim.data.site_xpos[self.handle1_site_id]
+
+ @property
+ def _pot_quat(self):
+ """
+ Grab the orientation of the pot body.
+
+ Returns:
+ np.array: (x,y,z,w) quaternion of the pot body
+ """
+ return T.convert_quat(self.sim.data.body_xquat[self.pot_body_id], to="xyzw")
+
+ @property
+ def _gripper0_to_handle0(self):
+ """
+ Calculate vector from the left gripper to the left pot handle.
+
+ Returns:
+ np.array: (dx,dy,dz) distance vector between handle and EEF0
+ """
+ return self._handle0_xpos - self._eef0_xpos
+
+ @property
+ def _gripper1_to_handle1(self):
+ """
+ Calculate vector from the right gripper to the right pot handle.
+
+ Returns:
+ np.array: (dx,dy,dz) distance vector between handle and EEF0
+ """
+ return self._handle1_xpos - self._eef1_xpos
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_peg_in_hole.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_peg_in_hole.py
new file mode 100644
index 0000000000000000000000000000000000000000..02f46b2fc814b538f5627087508a6656de8144ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_peg_in_hole.py
@@ -0,0 +1,518 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.two_arm_env import TwoArmEnv
+from robosuite.models.arenas import EmptyArena
+from robosuite.models.objects import CylinderObject, PlateWithHoleObject
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.mjcf_utils import CustomMaterial, array_to_string, find_elements
+from robosuite.utils.observables import Observable, sensor
+
+
+class TwoArmPegInHole(TwoArmEnv):
+ """
+ This class corresponds to the peg-in-hole task for two robot arms.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be either 2 single single-arm robots or 1 bimanual robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment. Can be either:
+
+ :`'bimanual'`: Only applicable for bimanual robot setups. Sets up the (single) bimanual robot on the -x
+ side of the table
+ :`'single-arm-parallel'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots next to each other on the -x side of the table
+ :`'single-arm-opposed'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots opposed from each others on the opposite +/-y sides of the table.
+
+ Note that "default" corresponds to either "bimanual" if a bimanual robot is used or "single-arm-opposed" if two
+ single-arm robots are used.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate gripper models from gripper factory.
+ For this environment, setting a value other than the default (None) will raise an AssertionError, as
+ this environment is not meant to be used with any gripper at all.
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ use_camera_obs (bool or list of bool): if True, every observation for a specific robot includes a rendered
+ image. Should either be single bool if camera obs value is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ peg_radius (2-tuple): low and high limits of the (uniformly sampled)
+ radius of the peg
+
+ peg_length (float): length of the peg
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ AssertionError: [Gripper specified]
+ ValueError: [Invalid number of robots specified]
+ ValueError: [Invalid env configuration]
+ ValueError: [Invalid robots for specified env configuration]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types=None,
+ initialization_noise="default",
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ peg_radius=(0.015, 0.03),
+ peg_length=0.13,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # Assert that the gripper type is None
+ assert gripper_types is None, "Tried to specify gripper other than None in TwoArmPegInHole environment!"
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ # Save peg specs
+ self.peg_radius = peg_radius
+ self.peg_length = peg_length
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 5.0 is provided if the peg is inside the plate's hole
+ - Note that we enforce that it's inside at an appropriate angle (cos(theta) > 0.95).
+
+ Un-normalized summed components if using reward shaping:
+
+ - Reaching: in [0, 1], to encourage the arms to approach each other
+ - Perpendicular Distance: in [0,1], to encourage the arms to approach each other
+ - Parallel Distance: in [0,1], to encourage the arms to approach each other
+ - Alignment: in [0, 1], to encourage having the right orientation between the peg and hole.
+ - Placement: in {0, 1}, nonzero if the peg is in the hole with a relatively correct alignment
+
+ Note that the final reward is normalized and scaled by reward_scale / 5.0 as
+ well so that the max score is equal to reward_scale
+
+ """
+ reward = 0
+
+ # Right location and angle
+ if self._check_success():
+ reward = 1.0
+
+ # use a shaping reward
+ if self.reward_shaping:
+ # Grab relevant values
+ t, d, cos = self._compute_orientation()
+ # reaching reward
+ hole_pos = self.sim.data.body_xpos[self.hole_body_id]
+ gripper_site_pos = self.sim.data.body_xpos[self.peg_body_id]
+ dist = np.linalg.norm(gripper_site_pos - hole_pos)
+ reaching_reward = 1 - np.tanh(1.0 * dist)
+ reward += reaching_reward
+
+ # Orientation reward
+ reward += 1 - np.tanh(d)
+ reward += 1 - np.tanh(np.abs(t))
+ reward += cos
+
+ # if we're not reward shaping, scale sparse reward so that the max reward is identical to its dense version
+ else:
+ reward *= 5.0
+
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 5.0
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose(s) accordingly
+ if self.env_configuration == "bimanual":
+ xpos = self.robots[0].robot_model.base_xpos_offset["empty"]
+ self.robots[0].robot_model.set_base_xpos(xpos)
+ else:
+ if self.env_configuration == "single-arm-opposed":
+ # Set up robots facing towards each other by rotating them from their default position
+ for robot, rotation in zip(self.robots, (np.pi / 2, -np.pi / 2)):
+ xpos = robot.robot_model.base_xpos_offset["empty"]
+ rot = np.array((0, 0, rotation))
+ xpos = T.euler2mat(rot) @ np.array(xpos)
+ robot.robot_model.set_base_xpos(xpos)
+ robot.robot_model.set_base_ori(rot)
+ else: # "single-arm-parallel" configuration setting
+ # Set up robots parallel to each other but offset from the center
+ for robot, offset in zip(self.robots, (-0.25, 0.25)):
+ xpos = robot.robot_model.base_xpos_offset["empty"]
+ xpos = np.array(xpos) + np.array((0, offset, 0))
+ robot.robot_model.set_base_xpos(xpos)
+
+ # Add arena and robot
+ mujoco_arena = EmptyArena()
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # Modify default agentview camera
+ mujoco_arena.set_camera(
+ camera_name="agentview",
+ pos=[1.0666432116509934, 1.4903257668114777e-08, 2.0563394967349096],
+ quat=[0.6530979871749878, 0.27104058861732483, 0.27104055881500244, 0.6530978679656982],
+ )
+
+ # initialize objects of interest
+ self.hole = PlateWithHoleObject(name="hole")
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "1 1",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ greenwood = CustomMaterial(
+ texture="WoodGreen",
+ tex_name="greenwood",
+ mat_name="greenwood_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.peg = CylinderObject(
+ name="peg",
+ size_min=(self.peg_radius[0], self.peg_length),
+ size_max=(self.peg_radius[1], self.peg_length),
+ material=greenwood,
+ rgba=[0, 1, 0, 1],
+ joints=None,
+ )
+
+ # Load hole object
+ hole_obj = self.hole.get_obj()
+ hole_obj.set("quat", "0 0 0.707 0.707")
+ hole_obj.set("pos", "0.11 0 0.17")
+
+ # Load peg object
+ peg_obj = self.peg.get_obj()
+ peg_obj.set("pos", array_to_string((0, 0, self.peg_length)))
+
+ # Append appropriate objects to arms
+ if self.env_configuration == "bimanual":
+ r_eef, l_eef = [self.robots[0].robot_model.eef_name[arm] for arm in self.robots[0].arms]
+ r_model, l_model = [self.robots[0].robot_model, self.robots[0].robot_model]
+ else:
+ r_eef, l_eef = [robot.robot_model.eef_name for robot in self.robots]
+ r_model, l_model = [self.robots[0].robot_model, self.robots[1].robot_model]
+ r_body = find_elements(root=r_model.worldbody, tags="body", attribs={"name": r_eef}, return_first=True)
+ l_body = find_elements(root=l_model.worldbody, tags="body", attribs={"name": l_eef}, return_first=True)
+ r_body.append(peg_obj)
+ l_body.append(hole_obj)
+
+ # task includes arena, robot, and objects of interest
+ # We don't add peg and hole directly since they were already appended to the robots
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ )
+
+ # Make sure to add relevant assets from peg and hole objects
+ self.model.merge_assets(self.hole)
+ self.model.merge_assets(self.peg)
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Additional object references from this env
+ self.hole_body_id = self.sim.model.body_name2id(self.hole.root_body)
+ self.peg_body_id = self.sim.model.body_name2id(self.peg.root_body)
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ if self.env_configuration == "bimanual":
+ pf0 = self.robots[0].robot_model.naming_prefix + "right_"
+ pf1 = self.robots[0].robot_model.naming_prefix + "left_"
+ else:
+ pf0 = self.robots[0].robot_model.naming_prefix
+ pf1 = self.robots[1].robot_model.naming_prefix
+ modality = "object"
+
+ # position and rotation of peg and hole
+ @sensor(modality=modality)
+ def hole_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.hole_body_id])
+
+ @sensor(modality=modality)
+ def hole_quat(obs_cache):
+ return T.convert_quat(self.sim.data.body_xquat[self.hole_body_id], to="xyzw")
+
+ @sensor(modality=modality)
+ def peg_to_hole(obs_cache):
+ return (
+ obs_cache["hole_pos"] - np.array(self.sim.data.body_xpos[self.peg_body_id])
+ if "hole_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def peg_quat(obs_cache):
+ return T.convert_quat(self.sim.data.body_xquat[self.peg_body_id], to="xyzw")
+
+ # Relative orientation parameters
+ @sensor(modality=modality)
+ def angle(obs_cache):
+ t, d, cos = self._compute_orientation()
+ obs_cache["t"] = t
+ obs_cache["d"] = d
+ return cos
+
+ @sensor(modality=modality)
+ def t(obs_cache):
+ return obs_cache["t"] if "t" in obs_cache else 0.0
+
+ @sensor(modality=modality)
+ def d(obs_cache):
+ return obs_cache["d"] if "d" in obs_cache else 0.0
+
+ sensors = [hole_pos, hole_quat, peg_to_hole, peg_quat, angle, t, d]
+ names = [s.__name__ for s in sensors]
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ def _check_success(self):
+ """
+ Check if peg is successfully aligned and placed within the hole
+
+ Returns:
+ bool: True if peg is placed in hole correctly
+ """
+ t, d, cos = self._compute_orientation()
+
+ return d < 0.06 and -0.12 <= t <= 0.14 and cos > 0.95
+
+ def _compute_orientation(self):
+ """
+ Helper function to return the relative positions between the hole and the peg.
+ In particular, the intersection of the line defined by the peg and the plane
+ defined by the hole is computed; the parallel distance, perpendicular distance,
+ and angle are returned.
+
+ Returns:
+ 3-tuple:
+
+ - (float): parallel distance
+ - (float): perpendicular distance
+ - (float): angle
+ """
+ peg_mat = self.sim.data.body_xmat[self.peg_body_id]
+ peg_mat.shape = (3, 3)
+ peg_pos = self.sim.data.body_xpos[self.peg_body_id]
+
+ hole_pos = self.sim.data.body_xpos[self.hole_body_id]
+ hole_mat = self.sim.data.body_xmat[self.hole_body_id]
+ hole_mat.shape = (3, 3)
+
+ v = peg_mat @ np.array([0, 0, 1])
+ v = v / np.linalg.norm(v)
+ center = hole_pos + hole_mat @ np.array([0.1, 0, 0])
+
+ t = (center - peg_pos) @ v / (np.linalg.norm(v) ** 2)
+ d = np.linalg.norm(np.cross(v, peg_pos - center)) / np.linalg.norm(v)
+
+ hole_normal = hole_mat @ np.array([0, 0, 1])
+ return (
+ t,
+ d,
+ abs(np.dot(hole_normal, v) / np.linalg.norm(hole_normal) / np.linalg.norm(v)),
+ )
+
+ def _peg_pose_in_hole_frame(self):
+ """
+ A helper function that takes in a named data field and returns the pose of that
+ object in the base frame.
+
+ Returns:
+ np.array: (4,4) matrix corresponding to the pose of the peg in the hole frame
+ """
+ # World frame
+ peg_pos_in_world = self.sim.data.get_body_xpos(self.peg.root_body)
+ peg_rot_in_world = self.sim.data.get_body_xmat(self.peg.root_body).reshape((3, 3))
+ peg_pose_in_world = T.make_pose(peg_pos_in_world, peg_rot_in_world)
+
+ # World frame
+ hole_pos_in_world = self.sim.data.get_body_xpos(self.hole.root_body)
+ hole_rot_in_world = self.sim.data.get_body_xmat(self.hole.root_body).reshape((3, 3))
+ hole_pose_in_world = T.make_pose(hole_pos_in_world, hole_rot_in_world)
+
+ world_pose_in_hole = T.pose_inv(hole_pose_in_world)
+
+ peg_pose_in_hole = T.pose_in_A_to_pose_in_B(peg_pose_in_world, world_pose_in_hole)
+ return peg_pose_in_hole
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_transport.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_transport.py
new file mode 100644
index 0000000000000000000000000000000000000000..d989a30075944c4c815aca72450443a08dead386
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/two_arm_transport.py
@@ -0,0 +1,602 @@
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.environments.manipulation.two_arm_env import TwoArmEnv
+from robosuite.models.arenas import MultiTableArena
+from robosuite.models.objects import BoxObject, HammerObject, TransportGroup
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.mjcf_utils import CustomMaterial
+from robosuite.utils.observables import Observable, sensor
+from robosuite.utils.placement_samplers import SequentialCompositeSampler, UniformRandomSampler
+
+
+class TwoArmTransport(TwoArmEnv):
+ """
+ This class corresponds to the transport task for two robot arms, requiring a payload to be transported from an
+ initial bin into a target bin, while removing trash from the target bin to a trash bin.
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be either 2 single single-arm robots or 1 bimanual robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment. Can be either:
+
+ :`'bimanual'`: Only applicable for bimanual robot setups. Sets up the (single) bimanual robot on the -x
+ side of the table
+ :`'single-arm-parallel'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots next to each other on the -x side of the table
+ :`'single-arm-opposed'`: Only applicable for multi single arm setups. Sets up the (two) single armed
+ robots opposed from each others on the opposite +/-y sides of the table.
+
+ Note that "default" corresponds to either "bimanual" if a bimanual robot is used or "single-arm-opposed" if two
+ single-arm robots are used.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default grippers(s) associated
+ with the robot(s) the 'robots' specification. None removes the gripper, and any other (valid) model
+ overrides the default gripper. Should either be single str if same gripper type is to be used for all
+ robots or else it should be a list of the same length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ tables_boundary (3-tuple): x, y, and z dimensions of the table bounds. Two tables will be created at the edges of
+ this boundary
+
+ table_friction (3-tuple): the three mujoco friction parameters for
+ each table.
+
+ bin_size (3-tuple): (x,y,z) dimensions of bins to use
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ Raises:
+ ValueError: [Invalid number of robots specified]
+ ValueError: [Invalid env configuration]
+ ValueError: [Invalid robots for specified env configuration]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="default",
+ initialization_noise="default",
+ tables_boundary=(0.8, 1.2, 0.05),
+ table_friction=(1.0, 5e-3, 1e-4),
+ bin_size=(0.3, 0.3, 0.15),
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=False,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # settings for table top
+ self.tables_boundary = tables_boundary
+ self.table_full_size = np.array(tables_boundary)
+ self.table_full_size[1] *= 0.25 # each table size will only be a fraction of the full boundary
+ self.table_friction = table_friction
+ self.table_offsets = np.zeros((2, 3))
+ self.table_offsets[0, 1] = self.tables_boundary[1] * -3 / 8 # scale y offset
+ self.table_offsets[1, 1] = self.tables_boundary[1] * 3 / 8 # scale y offset
+ self.table_offsets[:, 2] = 0.8 # scale z offset
+ self.bin_size = np.array(bin_size)
+
+ # reward configuration
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+ self.height_threshold = 0.1 # threshold above the table surface which the payload is considered lifted
+
+ # whether to use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of 1.0 is provided when the payload is in the target bin and the trash is in the trash
+ bin
+
+ Un-normalized max-wise components if using reward shaping:
+
+ # TODO!
+
+ Note that the final reward is normalized and scaled by reward_scale / 1.0 as
+ well so that the max score is equal to reward_scale
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ # Initialize reward
+ reward = 0
+
+ # use a shaping reward if specified
+ if self.reward_shaping:
+ # TODO! So we print a warning and force sparse rewards
+ print(f"\n\nWarning! No dense reward current implemented for this task. Forcing sparse rewards\n\n")
+ self.reward_shaping = False
+
+ # Else this is the sparse reward setting
+ else:
+ # Provide reward if payload is in target bin and trash is in trash bin
+ if self._check_success():
+ reward = 1.0
+
+ if self.reward_scale is not None:
+ reward *= self.reward_scale / 1.0
+
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose(s) accordingly
+ if self.env_configuration == "bimanual":
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+ else:
+ if self.env_configuration == "single-arm-opposed":
+ # Set up robots facing towards each other by rotating them from their default position
+ for robot, rotation, offset in zip(self.robots, (np.pi / 2, -np.pi / 2), (-0.25, 0.25)):
+ xpos = robot.robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ rot = np.array((0, 0, rotation))
+ xpos = T.euler2mat(rot) @ np.array(xpos)
+ xpos += np.array((0, offset, 0))
+ robot.robot_model.set_base_xpos(xpos)
+ robot.robot_model.set_base_ori(rot)
+ else: # "single-arm-parallel" configuration setting
+ # Set up robots parallel to each other but offset from the center
+ for robot, offset in zip(self.robots, (-0.6, 0.6)):
+ xpos = robot.robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ xpos = np.array(xpos) + np.array((0, offset, 0))
+ robot.robot_model.set_base_xpos(xpos)
+
+ # load model for table top workspace
+ mujoco_arena = MultiTableArena(
+ table_offsets=self.table_offsets,
+ table_rots=0,
+ table_full_sizes=self.table_full_size,
+ table_frictions=self.table_friction,
+ has_legs=True,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # Modify default agentview camera
+ mujoco_arena.set_camera(
+ camera_name="agentview",
+ pos=[0.8894354364730311, -3.481824231498976e-08, 1.7383813133506494],
+ quat=[0.6530981063842773, 0.2710406184196472, 0.27104079723358154, 0.6530979871749878],
+ )
+
+ # TODO: Add built-in method into TwoArmEnv so we have an elegant way of automatically adding extra cameras to all these envs
+ # Add shoulder cameras
+ mujoco_arena.set_camera(
+ camera_name="shouldercamera0",
+ pos=[0.4430096057365183, -1.0697399743660143, 1.3639950119362048],
+ quat=[0.804057240486145, 0.5531665086746216, 0.11286306381225586, 0.18644218146800995],
+ )
+ mujoco_arena.set_camera(
+ camera_name="shouldercamera1",
+ pos=[-0.40900713993039983, 0.9613722572245062, 1.3084072951772754],
+ quat=[0.15484197437763214, 0.12077208608388901, -0.5476858019828796, -0.8133130073547363],
+ )
+
+ # Add relevant materials
+ # Textures to use
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ redwood = CustomMaterial(
+ texture="WoodRed",
+ tex_name="redwood",
+ mat_name="redwood_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+
+ # initialize objects of interest
+ payload = HammerObject(
+ name="payload",
+ handle_radius=0.015,
+ handle_length=0.20,
+ handle_density=150.0,
+ handle_friction=4.0,
+ head_density_ratio=1.5,
+ )
+ trash = BoxObject(name="trash", size=[0.02, 0.02, 0.02], material=redwood)
+ self.transport = TransportGroup(
+ name="transport",
+ payload=payload,
+ trash=trash,
+ bin_size=self.bin_size,
+ )
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ mujoco_objects=list(self.transport.objects.values()),
+ )
+
+ # Create placement initializer
+ self._get_placement_initializer()
+
+ def _get_placement_initializer(self):
+ """
+ Helper function for defining placement initializer and object sampling bounds
+ """
+ # Create placement initializer
+ self.placement_initializer = SequentialCompositeSampler(name="ObjectSampler")
+
+ # Pre-define settings for each object's placement
+ object_names = ["start_bin", "lid", "payload", "target_bin", "trash", "trash_bin"]
+ table_nums = [0, 0, 0, 1, 1, 1]
+ x_centers = [
+ self.table_full_size[0] * 0.25,
+ 0, # gets overridden anyways
+ 0, # gets overridden anyways
+ -self.table_full_size[0] * 0.25,
+ 0, # gets overridden anyways
+ self.table_full_size[0] * 0.25,
+ ]
+ pos_tol = 0.005
+ rot_centers = [0, 0, np.pi / 2, 0, 0, 0]
+ rot_tols = [0, 0, np.pi / 6, 0, 0.3 * np.pi, 0]
+ rot_axes = ["z", "z", "y", "z", "z", "z"]
+ for obj_name, x, r, r_tol, r_axis, table_num in zip(
+ object_names, x_centers, rot_centers, rot_tols, rot_axes, table_nums
+ ):
+ # Get name and table
+ obj = self.transport.objects[obj_name]
+ table_pos = self.table_offsets[table_num]
+ # Create sampler for this object and add it to the sequential sampler
+ self.placement_initializer.append_sampler(
+ sampler=UniformRandomSampler(
+ name=f"{obj_name}ObjectSampler",
+ mujoco_objects=obj,
+ x_range=[x - pos_tol, x + pos_tol],
+ y_range=[-pos_tol, pos_tol],
+ rotation=[r - r_tol, r + r_tol],
+ rotation_axis=r_axis,
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=False,
+ reference_pos=table_pos,
+ z_offset=0.001,
+ )
+ )
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # low-level object information
+ if self.use_object_obs:
+ # Get robot prefix and define observables modality
+ if self.env_configuration == "bimanual":
+ pf0 = self.robots[0].robot_model.naming_prefix + "right_"
+ pf1 = self.robots[0].robot_model.naming_prefix + "left_"
+ else:
+ pf0 = self.robots[0].robot_model.naming_prefix
+ pf1 = self.robots[1].robot_model.naming_prefix
+ modality = "object"
+
+ # position and rotation of payload
+ @sensor(modality=modality)
+ def payload_pos(obs_cache):
+ return np.array(self.transport.payload_pos)
+
+ @sensor(modality=modality)
+ def payload_quat(obs_cache):
+ return np.array(self.transport.payload_quat)
+
+ # position and rotation of trash
+ @sensor(modality=modality)
+ def trash_pos(obs_cache):
+ return np.array(self.transport.trash_pos)
+
+ @sensor(modality=modality)
+ def trash_quat(obs_cache):
+ return np.array(self.transport.trash_quat)
+
+ # position and rotation of lid handle
+ @sensor(modality=modality)
+ def lid_handle_pos(obs_cache):
+ return np.array(self.transport.lid_handle_pos)
+
+ @sensor(modality=modality)
+ def lid_handle_quat(obs_cache):
+ return np.array(self.transport.lid_handle_quat)
+
+ # bin positions
+ @sensor(modality=modality)
+ def target_bin_pos(obs_cache):
+ return np.array(self.transport.target_bin_pos)
+
+ @sensor(modality=modality)
+ def trash_bin_pos(obs_cache):
+ return np.array(self.transport.trash_bin_pos)
+
+ # Relevant egocentric positions for arm0
+ @sensor(modality=modality)
+ def gripper0_to_payload(obs_cache):
+ return (
+ obs_cache["payload_pos"] - obs_cache[f"{pf0}eef_pos"]
+ if "payload_pos" in obs_cache and f"{pf0}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def gripper0_to_lid_handle(obs_cache):
+ return (
+ obs_cache["lid_handle_pos"] - obs_cache[f"{pf0}eef_pos"]
+ if "lid_handle_pos" in obs_cache and f"{pf0}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ # Relevant egocentric positions for arm1
+ @sensor(modality=modality)
+ def gripper1_to_payload(obs_cache):
+ return (
+ obs_cache["payload_pos"] - obs_cache[f"{pf1}eef_pos"]
+ if "payload_pos" in obs_cache and f"{pf1}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ @sensor(modality=modality)
+ def gripper1_to_trash(obs_cache):
+ return (
+ obs_cache["trash_pos"] - obs_cache[f"{pf1}eef_pos"]
+ if "trash_pos" in obs_cache and f"{pf1}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ # Key boolean checks
+ @sensor(modality=modality)
+ def payload_in_target_bin(obs_cache):
+ return self.transport.payload_in_target_bin
+
+ @sensor(modality=modality)
+ def trash_in_trash_bin(obs_cache):
+ return self.transport.trash_in_trash_bin
+
+ sensors = [
+ payload_pos,
+ payload_quat,
+ trash_pos,
+ trash_quat,
+ lid_handle_pos,
+ lid_handle_quat,
+ target_bin_pos,
+ trash_bin_pos,
+ gripper0_to_payload,
+ gripper0_to_lid_handle,
+ gripper1_to_payload,
+ gripper1_to_trash,
+ payload_in_target_bin,
+ trash_in_trash_bin,
+ ]
+ names = [s.__name__ for s in sensors]
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ super()._reset_internal()
+
+ # Update sim
+ self.transport.update_sim(sim=self.sim)
+
+ # Reset all object positions using initializer sampler if we're not directly loading from an xml
+ if not self.deterministic_reset:
+
+ # Sample from the placement initializer for all objects
+ object_placements = self.placement_initializer.sample()
+
+ # Initialize placeholders that we'll need to override the payload, lid, and trash object locations
+ start_bin_pos = None
+ target_bin_pos = None
+
+ # Loop through all objects and reset their positions
+ for obj_pos, obj_quat, obj in object_placements.values():
+ # If this is toolbox or good bin, store their sampled positions
+ if "start_bin" in obj.name and "lid" not in obj.name:
+ start_bin_pos = obj_pos
+ elif "target_bin" in obj.name:
+ target_bin_pos = obj_pos
+ # Else if this is either the lid, payload, or trash object,
+ # we override their positions to match their respective containers' positions
+ elif "lid" in obj.name:
+ obj_pos = (start_bin_pos[0], start_bin_pos[1], obj_pos[2] + self.transport.bin_size[2])
+ elif "payload" in obj.name:
+ obj_pos = (
+ start_bin_pos[0],
+ start_bin_pos[1],
+ obj_pos[2] + self.transport.objects["start_bin"].wall_thickness,
+ )
+ elif "trash" in obj.name and "bin" not in obj.name:
+ obj_pos = (
+ target_bin_pos[0],
+ target_bin_pos[1],
+ obj_pos[2] + self.transport.objects["target_bin"].wall_thickness,
+ )
+ # Set the collision object joints
+ self.sim.data.set_joint_qpos(obj.joints[0], np.concatenate([np.array(obj_pos), np.array(obj_quat)]))
+
+ def _check_success(self):
+ """
+ Check if payload is in target in and trash is in trash bin
+
+ Returns:
+ bool: True if transport has been completed
+ """
+ return True if self.transport.payload_in_target_bin and self.transport.trash_in_trash_bin else False
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/wipe.py b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/wipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6af44132dce0d9b36f2159e7ed12683c54ddb77
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/manipulation/wipe.py
@@ -0,0 +1,768 @@
+import multiprocessing
+from collections import OrderedDict
+
+import numpy as np
+
+from robosuite.environments.manipulation.single_arm_env import SingleArmEnv
+from robosuite.models.arenas import WipeArena
+from robosuite.models.tasks import ManipulationTask
+from robosuite.utils.observables import Observable, sensor
+
+# Default Wipe environment configuration
+DEFAULT_WIPE_CONFIG = {
+ # settings for reward
+ "arm_limit_collision_penalty": -10.0, # penalty for reaching joint limit or arm collision (except the wiping tool) with the table
+ "wipe_contact_reward": 0.01, # reward for contacting something with the wiping tool
+ "unit_wiped_reward": 50.0, # reward per peg wiped
+ "ee_accel_penalty": 0, # penalty for large end-effector accelerations
+ "excess_force_penalty_mul": 0.05, # penalty for each step that the force is over the safety threshold
+ "distance_multiplier": 5.0, # multiplier for the dense reward inversely proportional to the mean location of the pegs to wipe
+ "distance_th_multiplier": 5.0, # multiplier in the tanh function for the aforementioned reward
+ # settings for table top
+ "table_full_size": [0.5, 0.8, 0.05], # Size of tabletop
+ "table_offset": [0.15, 0, 0.9], # Offset of table (z dimension defines max height of table)
+ "table_friction": [0.03, 0.005, 0.0001], # Friction parameters for the table
+ "table_friction_std": 0, # Standard deviation to sample different friction parameters for the table each episode
+ "table_height": 0.0, # Additional height of the table over the default location
+ "table_height_std": 0.0, # Standard deviation to sample different heigths of the table each episode
+ "line_width": 0.04, # Width of the line to wipe (diameter of the pegs)
+ "two_clusters": False, # if the dirt to wipe is one continuous line or two
+ "coverage_factor": 0.6, # how much of the table surface we cover
+ "num_markers": 100, # How many particles of dirt to generate in the environment
+ # settings for thresholds
+ "contact_threshold": 1.0, # Minimum eef force to qualify as contact [N]
+ "pressure_threshold": 0.5, # force threshold (N) to overcome to get increased contact wiping reward
+ "pressure_threshold_max": 60.0, # maximum force allowed (N)
+ # misc settings
+ "print_results": False, # Whether to print results or not
+ "get_info": False, # Whether to grab info after each env step if not
+ "use_robot_obs": True, # if we use robot observations (proprioception) as input to the policy
+ "use_contact_obs": True, # if we use a binary observation for whether robot is in contact or not
+ "early_terminations": True, # Whether we allow for early terminations or not
+ "use_condensed_obj_obs": True, # Whether to use condensed object observation representation (only applicable if obj obs is active)
+}
+
+
+class Wipe(SingleArmEnv):
+ """
+ This class corresponds to the Wiping task for a single robot arm
+
+ Args:
+ robots (str or list of str): Specification for specific robot arm(s) to be instantiated within this env
+ (e.g: "Sawyer" would generate one arm; ["Panda", "Panda", "Sawyer"] would generate three robot arms)
+ Note: Must be a single single-arm robot!
+
+ env_configuration (str): Specifies how to position the robots within the environment (default is "default").
+ For most single arm environments, this argument has no impact on the robot setup.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ gripper_types (str or list of str): type of gripper, used to instantiate
+ gripper models from gripper factory.
+ For this environment, setting a value other than the default ("WipingGripper") will raise an
+ AssertionError, as this environment is not meant to be used with any other alternative gripper.
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ use_object_obs (bool): if True, include object (cube) information in
+ the observation.
+
+ reward_scale (None or float): Scales the normalized reward function by the amount specified.
+ If None, environment reward remains unnormalized
+
+ reward_shaping (bool): if True, use dense rewards.
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ task_config (None or dict): Specifies the parameters relevant to this task. For a full list of expected
+ parameters, see the default configuration dict at the top of this file.
+ If None is specified, the default configuration will be used.
+
+ Raises:
+ AssertionError: [Gripper specified]
+ AssertionError: [Bad reward specification]
+ AssertionError: [Invalid number of robots specified]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ controller_configs=None,
+ gripper_types="WipingGripper",
+ initialization_noise="default",
+ use_camera_obs=True,
+ use_object_obs=True,
+ reward_scale=1.0,
+ reward_shaping=True,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None, # {None, instance, class, element}
+ task_config=None,
+ renderer="mujoco",
+ renderer_config=None,
+ ):
+ # Assert that the gripper type is None
+ assert (
+ gripper_types == "WipingGripper"
+ ), "Tried to specify gripper other than WipingGripper in Wipe environment!"
+
+ # Get config
+ self.task_config = task_config if task_config is not None else DEFAULT_WIPE_CONFIG
+
+ # Set task-specific parameters
+
+ # settings for the reward
+ self.reward_scale = reward_scale
+ self.reward_shaping = reward_shaping
+ self.arm_limit_collision_penalty = self.task_config["arm_limit_collision_penalty"]
+ self.wipe_contact_reward = self.task_config["wipe_contact_reward"]
+ self.unit_wiped_reward = self.task_config["unit_wiped_reward"]
+ self.ee_accel_penalty = self.task_config["ee_accel_penalty"]
+ self.excess_force_penalty_mul = self.task_config["excess_force_penalty_mul"]
+ self.distance_multiplier = self.task_config["distance_multiplier"]
+ self.distance_th_multiplier = self.task_config["distance_th_multiplier"]
+ # Final reward computation
+ # So that is better to finish that to stay touching the table for 100 steps
+ # The 0.5 comes from continuous_distance_reward at 0. If something changes, this may change as well
+ self.task_complete_reward = self.unit_wiped_reward * (self.wipe_contact_reward + 0.5)
+ # Verify that the distance multiplier is not greater than the task complete reward
+ assert (
+ self.task_complete_reward > self.distance_multiplier
+ ), "Distance multiplier cannot be greater than task complete reward!"
+
+ # settings for table top
+ self.table_full_size = self.task_config["table_full_size"]
+ self.table_height = self.task_config["table_height"]
+ self.table_height_std = self.task_config["table_height_std"]
+ delta_height = min(0, np.random.normal(self.table_height, self.table_height_std)) # sample variation in height
+ self.table_offset = np.array(self.task_config["table_offset"]) + np.array((0, 0, delta_height))
+ self.table_friction = self.task_config["table_friction"]
+ self.table_friction_std = self.task_config["table_friction_std"]
+ self.line_width = self.task_config["line_width"]
+ self.two_clusters = self.task_config["two_clusters"]
+ self.coverage_factor = self.task_config["coverage_factor"]
+ self.num_markers = self.task_config["num_markers"]
+
+ # settings for thresholds
+ self.contact_threshold = self.task_config["contact_threshold"]
+ self.pressure_threshold = self.task_config["pressure_threshold"]
+ self.pressure_threshold_max = self.task_config["pressure_threshold_max"]
+
+ # misc settings
+ self.print_results = self.task_config["print_results"]
+ self.get_info = self.task_config["get_info"]
+ self.use_robot_obs = self.task_config["use_robot_obs"]
+ self.use_contact_obs = self.task_config["use_contact_obs"]
+ self.early_terminations = self.task_config["early_terminations"]
+ self.use_condensed_obj_obs = self.task_config["use_condensed_obj_obs"]
+
+ # Scale reward if desired (see reward method for details)
+ self.reward_normalization_factor = horizon / (
+ self.num_markers * self.unit_wiped_reward + horizon * (self.wipe_contact_reward + self.task_complete_reward)
+ )
+
+ # ee resets
+ self.ee_force_bias = np.zeros(3)
+ self.ee_torque_bias = np.zeros(3)
+
+ # set other wipe-specific attributes
+ self.wiped_markers = []
+ self.collisions = 0
+ self.f_excess = 0
+ self.metadata = []
+ self.spec = "spec"
+
+ # whether to include and use ground-truth object states
+ self.use_object_obs = use_object_obs
+
+ super().__init__(
+ robots=robots,
+ env_configuration=env_configuration,
+ controller_configs=controller_configs,
+ mount_types="default",
+ gripper_types=gripper_types,
+ initialization_noise=initialization_noise,
+ use_camera_obs=use_camera_obs,
+ has_renderer=has_renderer,
+ has_offscreen_renderer=has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ camera_names=camera_names,
+ camera_heights=camera_heights,
+ camera_widths=camera_widths,
+ camera_depths=camera_depths,
+ camera_segmentations=camera_segmentations,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def reward(self, action=None):
+ """
+ Reward function for the task.
+
+ Sparse un-normalized reward:
+
+ - a discrete reward of self.unit_wiped_reward is provided per single dirt (peg) wiped during this step
+ - a discrete reward of self.task_complete_reward is provided if all dirt is wiped
+
+ Note that if the arm is either colliding or near its joint limit, a reward of 0 will be automatically given
+
+ Un-normalized summed components if using reward shaping (individual components can be set to 0:
+
+ - Reaching: in [0, self.distance_multiplier], proportional to distance between wiper and centroid of dirt
+ and zero if the table has been fully wiped clean of all the dirt
+ - Table Contact: in {0, self.wipe_contact_reward}, non-zero if wiper is in contact with table
+ - Wiping: in {0, self.unit_wiped_reward}, non-zero for each dirt (peg) wiped during this step
+ - Cleaned: in {0, self.task_complete_reward}, non-zero if no dirt remains on the table
+ - Collision / Joint Limit Penalty: in {self.arm_limit_collision_penalty, 0}, nonzero if robot arm
+ is colliding with an object
+ - Note that if this value is nonzero, no other reward components can be added
+ - Large Force Penalty: in [-inf, 0], scaled by wiper force and directly proportional to
+ self.excess_force_penalty_mul if the current force exceeds self.pressure_threshold_max
+ - Large Acceleration Penalty: in [-inf, 0], scaled by estimated wiper acceleration and directly
+ proportional to self.ee_accel_penalty
+
+ Note that the final per-step reward is normalized given the theoretical best episode return and then scaled:
+ reward_scale * (horizon /
+ (num_markers * unit_wiped_reward + horizon * (wipe_contact_reward + task_complete_reward)))
+
+ Args:
+ action (np array): [NOT USED]
+
+ Returns:
+ float: reward value
+ """
+ reward = 0
+
+ total_force_ee = np.linalg.norm(np.array(self.robots[0].recent_ee_forcetorques.current[:3]))
+
+ # Neg Reward from collisions of the arm with the table
+ if self.check_contact(self.robots[0].robot_model):
+ if self.reward_shaping:
+ reward = self.arm_limit_collision_penalty
+ self.collisions += 1
+ elif self.robots[0].check_q_limits():
+ if self.reward_shaping:
+ reward = self.arm_limit_collision_penalty
+ self.collisions += 1
+ else:
+ # If the arm is not colliding or in joint limits, we check if we are wiping
+ # (we don't want to reward wiping if there are unsafe situations)
+ active_markers = []
+
+ # Current 3D location of the corners of the wiping tool in world frame
+ c_geoms = self.robots[0].gripper.important_geoms["corners"]
+ corner1_id = self.sim.model.geom_name2id(c_geoms[0])
+ corner1_pos = np.array(self.sim.data.geom_xpos[corner1_id])
+ corner2_id = self.sim.model.geom_name2id(c_geoms[1])
+ corner2_pos = np.array(self.sim.data.geom_xpos[corner2_id])
+ corner3_id = self.sim.model.geom_name2id(c_geoms[2])
+ corner3_pos = np.array(self.sim.data.geom_xpos[corner3_id])
+ corner4_id = self.sim.model.geom_name2id(c_geoms[3])
+ corner4_pos = np.array(self.sim.data.geom_xpos[corner4_id])
+
+ # Unit vectors on my plane
+ v1 = corner1_pos - corner2_pos
+ v1 /= np.linalg.norm(v1)
+ v2 = corner4_pos - corner2_pos
+ v2 /= np.linalg.norm(v2)
+
+ # Corners of the tool in the coordinate frame of the plane
+ t1 = np.array([np.dot(corner1_pos - corner2_pos, v1), np.dot(corner1_pos - corner2_pos, v2)])
+ t2 = np.array([np.dot(corner2_pos - corner2_pos, v1), np.dot(corner2_pos - corner2_pos, v2)])
+ t3 = np.array([np.dot(corner3_pos - corner2_pos, v1), np.dot(corner3_pos - corner2_pos, v2)])
+ t4 = np.array([np.dot(corner4_pos - corner2_pos, v1), np.dot(corner4_pos - corner2_pos, v2)])
+
+ pp = [t1, t2, t4, t3]
+
+ # Normal of the plane defined by v1 and v2
+ n = np.cross(v1, v2)
+ n /= np.linalg.norm(n)
+
+ def isLeft(P0, P1, P2):
+ return (P1[0] - P0[0]) * (P2[1] - P0[1]) - (P2[0] - P0[0]) * (P1[1] - P0[1])
+
+ def PointInRectangle(X, Y, Z, W, P):
+ return isLeft(X, Y, P) < 0 and isLeft(Y, Z, P) < 0 and isLeft(Z, W, P) < 0 and isLeft(W, X, P) < 0
+
+ # Only go into this computation if there are contact points
+ if self.sim.data.ncon != 0:
+
+ # Check each marker that is still active
+ for marker in self.model.mujoco_arena.markers:
+
+ # Current marker 3D location in world frame
+ marker_pos = np.array(self.sim.data.body_xpos[self.sim.model.body_name2id(marker.root_body)])
+
+ # We use the second tool corner as point on the plane and define the vector connecting
+ # the marker position to that point
+ v = marker_pos - corner2_pos
+
+ # Shortest distance between the center of the marker and the plane
+ dist = np.dot(v, n)
+
+ # Projection of the center of the marker onto the plane
+ projected_point = np.array(marker_pos) - dist * n
+
+ # Positive distances means the center of the marker is over the plane
+ # The plane is aligned with the bottom of the wiper and pointing up, so the marker would be over it
+ if dist > 0.0:
+ # Distance smaller than this threshold means we are close to the plane on the upper part
+ if dist < 0.02:
+ # Write touching points and projected point in coordinates of the plane
+ pp_2 = np.array(
+ [np.dot(projected_point - corner2_pos, v1), np.dot(projected_point - corner2_pos, v2)]
+ )
+ # Check if marker is within the tool center:
+ if PointInRectangle(pp[0], pp[1], pp[2], pp[3], pp_2):
+ active_markers.append(marker)
+
+ # Obtain the list of currently active (wiped) markers that where not wiped before
+ # These are the markers we are wiping at this step
+ lall = np.where(np.isin(active_markers, self.wiped_markers, invert=True))
+ new_active_markers = np.array(active_markers)[lall]
+
+ # Loop through all new markers we are wiping at this step
+ for new_active_marker in new_active_markers:
+ # Grab relevant marker id info
+ new_active_marker_geom_id = self.sim.model.geom_name2id(new_active_marker.visual_geoms[0])
+ # Make this marker transparent since we wiped it (alpha = 0)
+ self.sim.model.geom_rgba[new_active_marker_geom_id][3] = 0
+ # Add this marker the wiped list
+ self.wiped_markers.append(new_active_marker)
+ # Add reward if we're using the dense reward
+ if self.reward_shaping:
+ reward += self.unit_wiped_reward
+
+ # Additional reward components if using dense rewards
+ if self.reward_shaping:
+ # If we haven't wiped all the markers yet, add a smooth reward for getting closer
+ # to the centroid of the dirt to wipe
+ if len(self.wiped_markers) < self.num_markers:
+ _, _, mean_pos_to_things_to_wipe = self._get_wipe_information()
+ mean_distance_to_things_to_wipe = np.linalg.norm(mean_pos_to_things_to_wipe)
+ reward += self.distance_multiplier * (
+ 1 - np.tanh(self.distance_th_multiplier * mean_distance_to_things_to_wipe)
+ )
+
+ # Reward for keeping contact
+ if self.sim.data.ncon != 0 and self._has_gripper_contact:
+ reward += self.wipe_contact_reward
+
+ # Penalty for excessive force with the end-effector
+ if total_force_ee > self.pressure_threshold_max:
+ reward -= self.excess_force_penalty_mul * total_force_ee
+ self.f_excess += 1
+
+ # Reward for pressing into table
+ # TODO: Need to include this computation somehow in the scaled reward computation
+ elif total_force_ee > self.pressure_threshold and self.sim.data.ncon > 1:
+ reward += self.wipe_contact_reward + 0.01 * total_force_ee
+ if self.sim.data.ncon > 50:
+ reward += 10.0 * self.wipe_contact_reward
+
+ # Penalize large accelerations
+ reward -= self.ee_accel_penalty * np.mean(abs(self.robots[0].recent_ee_acc.current))
+
+ # Final reward if all wiped
+ if len(self.wiped_markers) == self.num_markers:
+ reward += self.task_complete_reward
+
+ # Printing results
+ if self.print_results:
+ string_to_print = (
+ "Process {pid}, timestep {ts:>4}: reward: {rw:8.4f}"
+ "wiped markers: {ws:>3} collisions: {sc:>3} f-excess: {fe:>3}".format(
+ pid=id(multiprocessing.current_process()),
+ ts=self.timestep,
+ rw=reward,
+ ws=len(self.wiped_markers),
+ sc=self.collisions,
+ fe=self.f_excess,
+ )
+ )
+ print(string_to_print)
+
+ # If we're scaling our reward, we normalize the per-step rewards given the theoretical best episode return
+ # This is equivalent to scaling the reward by:
+ # reward_scale * (horizon /
+ # (num_markers * unit_wiped_reward + horizon * (wipe_contact_reward + task_complete_reward)))
+ if self.reward_scale:
+ reward *= self.reward_scale * self.reward_normalization_factor
+ return reward
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Adjust base pose accordingly
+ xpos = self.robots[0].robot_model.base_xpos_offset["table"](self.table_full_size[0])
+ self.robots[0].robot_model.set_base_xpos(xpos)
+
+ # Get robot's contact geoms
+ self.robot_contact_geoms = self.robots[0].robot_model.contact_geoms
+
+ mujoco_arena = WipeArena(
+ table_full_size=self.table_full_size,
+ table_friction=self.table_friction,
+ table_offset=self.table_offset,
+ table_friction_std=self.table_friction_std,
+ coverage_factor=self.coverage_factor,
+ num_markers=self.num_markers,
+ line_width=self.line_width,
+ two_clusters=self.two_clusters,
+ )
+
+ # Arena always gets set to zero origin
+ mujoco_arena.set_origin([0, 0, 0])
+
+ # task includes arena, robot, and objects of interest
+ self.model = ManipulationTask(
+ mujoco_arena=mujoco_arena,
+ mujoco_robots=[robot.robot_model for robot in self.robots],
+ )
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Creates object-based observables if enabled
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+
+ # Get prefix from robot model to avoid naming clashes for multiple robots
+ pf = self.robots[0].robot_model.naming_prefix
+ modality = "object"
+
+ sensors = []
+ names = []
+
+ # Add binary contact observation
+ if self.use_contact_obs:
+
+ @sensor(modality=f"{pf}proprio")
+ def gripper_contact(obs_cache):
+ return self._has_gripper_contact
+
+ sensors.append(gripper_contact)
+ names.append(f"{pf}contact")
+
+ # object information in the observation
+ if self.use_object_obs:
+
+ if self.use_condensed_obj_obs:
+ # use implicit representation of wiping objects
+ @sensor(modality=modality)
+ def wipe_radius(obs_cache):
+ wipe_rad, wipe_cent, _ = self._get_wipe_information()
+ obs_cache["wipe_centroid"] = wipe_cent
+ return wipe_rad
+
+ @sensor(modality=modality)
+ def wipe_centroid(obs_cache):
+ return obs_cache["wipe_centroid"] if "wipe_centroid" in obs_cache else np.zeros(3)
+
+ @sensor(modality=modality)
+ def proportion_wiped(obs_cache):
+ return len(self.wiped_markers) / self.num_markers
+
+ sensors += [proportion_wiped, wipe_radius, wipe_centroid]
+ names += ["proportion_wiped", "wipe_radius", "wipe_centroid"]
+
+ if self.use_robot_obs:
+ # also use ego-centric obs
+ @sensor(modality=modality)
+ def gripper_to_wipe_centroid(obs_cache):
+ return (
+ obs_cache["wipe_centroid"] - obs_cache[f"{pf}eef_pos"]
+ if "wipe_centroid" in obs_cache and f"{pf}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ sensors.append(gripper_to_wipe_centroid)
+ names.append("gripper_to_wipe_centroid")
+
+ else:
+ # use explicit representation of wiping objects
+ for i, marker in enumerate(self.model.mujoco_arena.markers):
+ marker_sensors, marker_sensor_names = self._create_marker_sensors(i, marker, modality)
+ sensors += marker_sensors
+ names += marker_sensor_names
+
+ # Create observables
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _create_marker_sensors(self, i, marker, modality="object"):
+ """
+ Helper function to create sensors for a given marker. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+
+ Args:
+ i (int): ID number corresponding to the marker
+ marker (MujocoObject): Marker to create sensors for
+ modality (str): Modality to assign to all sensors
+
+ Returns:
+ 2-tuple:
+ sensors (list): Array of sensors for the given marker
+ names (list): array of corresponding observable names
+ """
+ pf = self.robots[0].robot_model.naming_prefix
+
+ @sensor(modality=modality)
+ def marker_pos(obs_cache):
+ return np.array(self.sim.data.body_xpos[self.sim.model.body_name2id(marker.root_body)])
+
+ @sensor(modality=modality)
+ def marker_wiped(obs_cache):
+ return [0, 1][marker in self.wiped_markers]
+
+ sensors = [marker_pos, marker_wiped]
+ names = [f"marker{i}_pos", f"marker{i}_wiped"]
+
+ if self.use_robot_obs:
+ # also use ego-centric obs
+ @sensor(modality=modality)
+ def gripper_to_marker(obs_cache):
+ return (
+ obs_cache[f"marker{i}_pos"] - obs_cache[f"{pf}eef_pos"]
+ if f"marker{i}_pos" in obs_cache and f"{pf}eef_pos" in obs_cache
+ else np.zeros(3)
+ )
+
+ sensors.append(gripper_to_marker)
+ names.append(f"gripper_to_marker{i}")
+
+ return sensors, names
+
+ def _reset_internal(self):
+ super()._reset_internal()
+
+ # inherited class should reset positions of objects (only if we're not using a deterministic reset)
+ if not self.deterministic_reset:
+ self.model.mujoco_arena.reset_arena(self.sim)
+
+ # Reset all internal vars for this wipe task
+ self.timestep = 0
+ self.wiped_markers = []
+ self.collisions = 0
+ self.f_excess = 0
+
+ # ee resets - bias at initial state
+ self.ee_force_bias = np.zeros(3)
+ self.ee_torque_bias = np.zeros(3)
+
+ def _check_success(self):
+ """
+ Checks if Task succeeds (all dirt wiped).
+
+ Returns:
+ bool: True if completed task
+ """
+ return True if len(self.wiped_markers) == self.num_markers else False
+
+ def _check_terminated(self):
+ """
+ Check if the task has completed one way or another. The following conditions lead to termination:
+
+ - Collision
+ - Task completion (wiping succeeded)
+ - Joint Limit reached
+
+ Returns:
+ bool: True if episode is terminated
+ """
+
+ terminated = False
+
+ # Prematurely terminate if contacting the table with the arm
+ if self.check_contact(self.robots[0].robot_model):
+ if self.print_results:
+ print(40 * "-" + " COLLIDED " + 40 * "-")
+ terminated = True
+
+ # Prematurely terminate if task is success
+ if self._check_success():
+ if self.print_results:
+ print(40 * "+" + " FINISHED WIPING " + 40 * "+")
+ terminated = True
+
+ # Prematurely terminate if contacting the table with the arm
+ if self.robots[0].check_q_limits():
+ if self.print_results:
+ print(40 * "-" + " JOINT LIMIT " + 40 * "-")
+ terminated = True
+
+ return terminated
+
+ def _post_action(self, action):
+ """
+ In addition to super method, add additional info if requested
+
+ Args:
+ action (np.array): Action to execute within the environment
+
+ Returns:
+ 3-tuple:
+
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) info about current env step
+ """
+ reward, done, info = super()._post_action(action)
+
+ # Update force bias
+ if np.linalg.norm(self.ee_force_bias) == 0:
+ self.ee_force_bias = self.robots[0].ee_force
+ self.ee_torque_bias = self.robots[0].ee_torque
+
+ if self.get_info:
+ info["add_vals"] = ["nwipedmarkers", "colls", "percent_viapoints_", "f_excess"]
+ info["nwipedmarkers"] = len(self.wiped_markers)
+ info["colls"] = self.collisions
+ info["percent_viapoints_"] = len(self.wiped_markers) / self.num_markers
+ info["f_excess"] = self.f_excess
+
+ # allow episode to finish early if allowed
+ if self.early_terminations:
+ done = done or self._check_terminated()
+
+ return reward, done, info
+
+ def _get_wipe_information(self):
+ """Returns set of wiping information"""
+ mean_pos_to_things_to_wipe = np.zeros(3)
+ wipe_centroid = np.zeros(3)
+ marker_positions = []
+ num_non_wiped_markers = 0
+ if len(self.wiped_markers) < self.num_markers:
+ for marker in self.model.mujoco_arena.markers:
+ if marker not in self.wiped_markers:
+ marker_pos = np.array(self.sim.data.body_xpos[self.sim.model.body_name2id(marker.root_body)])
+ wipe_centroid += marker_pos
+ marker_positions.append(marker_pos)
+ num_non_wiped_markers += 1
+ wipe_centroid /= max(1, num_non_wiped_markers)
+ mean_pos_to_things_to_wipe = wipe_centroid - self._eef_xpos
+ # Radius of circle from centroid capturing all remaining wiping markers
+ max_radius = 0
+ if num_non_wiped_markers > 0:
+ max_radius = np.max(np.linalg.norm(np.array(marker_positions) - wipe_centroid, axis=1))
+ # Return all values
+ return max_radius, wipe_centroid, mean_pos_to_things_to_wipe
+
+ @property
+ def _has_gripper_contact(self):
+ """
+ Determines whether the gripper is making contact with an object, as defined by the eef force surprassing
+ a certain threshold defined by self.contact_threshold
+
+ Returns:
+ bool: True if contact is surpasses given threshold magnitude
+ """
+ return np.linalg.norm(self.robots[0].ee_force - self.ee_force_bias) > self.contact_threshold
diff --git a/phantom/submodules/phantom-robosuite/robosuite/environments/robot_env.py b/phantom/submodules/phantom-robosuite/robosuite/environments/robot_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c836655a0c3f69ebfa402acdb496b475ac1e573
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/environments/robot_env.py
@@ -0,0 +1,619 @@
+from collections import OrderedDict
+from copy import deepcopy
+
+import numpy as np
+
+import robosuite.macros as macros
+from robosuite.controllers import reset_controllers
+from robosuite.environments.base import MujocoEnv
+from robosuite.robots import ROBOT_CLASS_MAPPING
+from robosuite.utils.mjcf_utils import IMAGE_CONVENTION_MAPPING
+from robosuite.utils.observables import Observable, sensor
+
+
+class RobotEnv(MujocoEnv):
+ """
+ Initializes a robot environment in Mujoco.
+
+ Args:
+ robots: Specification for specific robot(s) to be instantiated within this env
+
+ env_configuration (str): Specifies how to position the robot(s) within the environment. Default is "default",
+ which should be interpreted accordingly by any subclasses.
+
+ controller_configs (str or list of dict): If set, contains relevant controller parameters for creating a
+ custom controller. Else, uses the default controller for this specific task. Should either be single
+ dict if same controller is to be used for all robots or else it should be a list of the same length as
+ "robots" param
+
+ mount_types (None or str or list of str): type of mount, used to instantiate mount models from mount factory.
+ Default is "default", which is the default mount associated with the robot(s) the 'robots' specification.
+ None results in no mount, and any other (valid) model overrides the default mount. Should either be
+ single str if same mount type is to be used for all robots or else it should be a list of the same
+ length as "robots" param
+
+ initialization_noise (dict or list of dict): Dict containing the initialization noise parameters.
+ The expected keys and corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to `None` or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ Should either be single dict if same noise value is to be used for all robots or else it should be a
+ list of the same length as "robots" param
+
+ :Note: Specifying "default" will automatically use the default noise settings.
+ Specifying None will automatically create the required dict with "magnitude" set to 0.0.
+
+ use_camera_obs (bool): if True, every observation includes rendered image(s)
+
+ has_renderer (bool): If true, render the simulation state in
+ a viewer instead of headless mode.
+
+ has_offscreen_renderer (bool): True if using off-screen rendering
+
+ render_camera (str): Name of camera to render if `has_renderer` is True. Setting this value to 'None'
+ will result in the default angle being applied, which is useful as it can be dragged / panned by
+ the user using the mouse
+
+ render_collision_mesh (bool): True if rendering collision meshes in camera. False otherwise.
+
+ render_visual_mesh (bool): True if rendering visual meshes in camera. False otherwise.
+
+ render_gpu_device_id (int): corresponds to the GPU device id to use for offscreen rendering.
+ Defaults to -1, in which case the device will be inferred from environment variables
+ (GPUS or CUDA_VISIBLE_DEVICES).
+
+ control_freq (float): how many control signals to receive in every second. This sets the amount of
+ simulation time that passes between every action input.
+
+ horizon (int): Every episode lasts for exactly @horizon timesteps.
+
+ ignore_done (bool): True if never terminating the environment (ignore @horizon).
+
+ hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
+ only calls sim.reset and resets all robosuite-internal variables
+
+ camera_names (str or list of str): name of camera to be rendered. Should either be single str if
+ same name is to be used for all cameras' rendering or else it should be a list of cameras to render.
+
+ :Note: At least one camera must be specified if @use_camera_obs is True.
+
+ :Note: To render all robots' cameras of a certain type (e.g.: "robotview" or "eye_in_hand"), use the
+ convention "all-{name}" (e.g.: "all-robotview") to automatically render all camera images from each
+ robot's camera list).
+
+ camera_heights (int or list of int): height of camera frame. Should either be single int if
+ same height is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_widths (int or list of int): width of camera frame. Should either be single int if
+ same width is to be used for all cameras' frames or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_depths (bool or list of bool): True if rendering RGB-D, and RGB otherwise. Should either be single
+ bool if same depth setting is to be used for all cameras or else it should be a list of the same length as
+ "camera names" param.
+
+ camera_segmentations (None or str or list of str or list of list of str): Camera segmentation(s) to use
+ for each camera. Valid options are:
+
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ If not None, multiple types of segmentations can be specified. A [list of str / str or None] specifies
+ [multiple / a single] segmentation(s) to use for all cameras. A list of list of str specifies per-camera
+ segmentation setting(s) to use.
+
+ robot_configs (list of dict): Per-robot configurations set from any subclass initializers.
+
+ Raises:
+ ValueError: [Camera obs require offscreen renderer]
+ ValueError: [Camera name must be specified to use camera obs]
+ """
+
+ def __init__(
+ self,
+ robots,
+ env_configuration="default",
+ mount_types="default",
+ controller_configs=None,
+ initialization_noise=None,
+ use_camera_obs=True,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ render_camera="frontview",
+ render_collision_mesh=False,
+ render_visual_mesh=True,
+ render_gpu_device_id=-1,
+ control_freq=20,
+ horizon=1000,
+ ignore_done=False,
+ hard_reset=True,
+ camera_names="agentview",
+ camera_heights=256,
+ camera_widths=256,
+ camera_depths=False,
+ camera_segmentations=None,
+ robot_configs=None,
+ renderer="mujoco",
+ renderer_config=None,
+ direct_gripper_control=False,
+ ):
+ # First, verify that correct number of robots are being inputted
+ self.env_configuration = env_configuration
+ self._check_robot_configuration(robots)
+
+ # Robot
+ robots = list(robots) if type(robots) is list or type(robots) is tuple else [robots]
+ self.num_robots = len(robots)
+ self.robot_names = robots
+ self.robots = self._input2list(None, self.num_robots)
+ self._action_dim = None
+
+ # Mount
+ mount_types = self._input2list(mount_types, self.num_robots)
+
+ # Controller
+ controller_configs = self._input2list(controller_configs, self.num_robots)
+
+ # Initialization Noise
+ initialization_noise = self._input2list(initialization_noise, self.num_robots)
+
+ # Observations -- Ground truth = object_obs, Image data = camera_obs
+ self.use_camera_obs = use_camera_obs
+
+ # Camera / Rendering Settings
+ self.has_offscreen_renderer = has_offscreen_renderer
+ self.camera_names = (
+ list(camera_names) if type(camera_names) is list or type(camera_names) is tuple else [camera_names]
+ )
+ self.num_cameras = len(self.camera_names)
+
+ self.camera_heights = self._input2list(camera_heights, self.num_cameras)
+ self.camera_widths = self._input2list(camera_widths, self.num_cameras)
+ self.camera_depths = self._input2list(camera_depths, self.num_cameras)
+ self.camera_segmentations = self._input2list(camera_segmentations, self.num_cameras)
+ # We need to parse camera segmentations more carefully since it may be a nested list
+ seg_is_nested = False
+ for i, camera_s in enumerate(self.camera_segmentations):
+ if isinstance(camera_s, list) or isinstance(camera_s, tuple):
+ seg_is_nested = True
+ break
+ camera_segs = deepcopy(self.camera_segmentations)
+ for i, camera_s in enumerate(self.camera_segmentations):
+ if camera_s is not None:
+ self.camera_segmentations[i] = self._input2list(camera_s, 1) if seg_is_nested else deepcopy(camera_segs)
+
+ # sanity checks for camera rendering
+ if self.use_camera_obs and not self.has_offscreen_renderer:
+ raise ValueError("Error: Camera observations require an offscreen renderer!")
+ if self.use_camera_obs and self.camera_names is None:
+ raise ValueError("Must specify at least one camera name when using camera obs")
+
+ # Robot configurations -- update from subclass configs
+ if robot_configs is None:
+ robot_configs = [{} for _ in range(self.num_robots)]
+ self.robot_configs = [
+ dict(
+ **{
+ "controller_config": controller_configs[idx],
+ "mount_type": mount_types[idx],
+ "initialization_noise": initialization_noise[idx],
+ "control_freq": control_freq,
+ "direct_gripper_control": direct_gripper_control,
+ },
+ **robot_config,
+ )
+ for idx, robot_config in enumerate(robot_configs)
+ ]
+
+ # Run superclass init
+ super().__init__(
+ has_renderer=has_renderer,
+ has_offscreen_renderer=self.has_offscreen_renderer,
+ render_camera=render_camera,
+ render_collision_mesh=render_collision_mesh,
+ render_visual_mesh=render_visual_mesh,
+ render_gpu_device_id=render_gpu_device_id,
+ control_freq=control_freq,
+ horizon=horizon,
+ ignore_done=ignore_done,
+ hard_reset=hard_reset,
+ renderer=renderer,
+ renderer_config=renderer_config,
+ )
+
+ def visualize(self, vis_settings):
+ """
+ In addition to super call, visualizes robots.
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "robots" keyword as well as any other relevant
+ options specified.
+ """
+ # Run superclass method first
+ super().visualize(vis_settings=vis_settings)
+ # Loop over robots to visualize them independently
+ for robot in self.robots:
+ robot.visualize(vis_settings=vis_settings)
+
+ @property
+ def _visualizations(self):
+ """
+ Visualization keywords for this environment
+
+ Returns:
+ set: All components that can be individually visualized for this environment
+ """
+ vis_set = super()._visualizations
+ vis_set.add("robots")
+ return vis_set
+
+ @property
+ def action_spec(self):
+ """
+ Action space (low, high) for this environment
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum (low) action values
+ - (np.array) maximum (high) action values
+ """
+ low, high = [], []
+ for robot in self.robots:
+ lo, hi = robot.action_limits
+ low, high = np.concatenate([low, lo]), np.concatenate([high, hi])
+ return low, high
+
+ @property
+ def action_dim(self):
+ """
+ Size of the action space
+
+ Returns:
+ int: Action space dimension
+ """
+ return self._action_dim
+
+ @staticmethod
+ def _input2list(inp, length):
+ """
+ Helper function that converts an input that is either a single value or a list into a list
+
+ Args:
+ inp (None or str or list): Input value to be converted to list
+ length (int): Length of list to broadcast input to
+
+ Returns:
+ list: input @inp converted into a list of length @length
+ """
+ # convert to list if necessary
+ return list(inp) if type(inp) is list or type(inp) is tuple else [inp for _ in range(length)]
+
+ def _load_model(self):
+ """
+ Loads an xml model, puts it in self.model
+ """
+ super()._load_model()
+
+ # Load robots
+ self._load_robots()
+
+ def _setup_references(self):
+ """
+ Sets up references to important components. A reference is typically an
+ index or a list of indices that point to the corresponding elements
+ in a flatten array, which is how MuJoCo stores physical simulation data.
+ """
+ super()._setup_references()
+
+ # Setup robot-specific references as well (note: requires resetting of sim for robot first)
+ for robot in self.robots:
+ robot.reset_sim(self.sim)
+ robot.setup_references()
+
+ def _setup_observables(self):
+ """
+ Sets up observables to be used for this environment. Loops through all robots and grabs their corresponding
+ observables to add to the procedurally generated dict of observables
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ observables = super()._setup_observables()
+ # Loop through all robots and grab their observables, adding it to the proprioception modality
+ for robot in self.robots:
+ robot_obs = robot.setup_observables()
+ observables.update(robot_obs)
+
+ # Loop through cameras and update the observations if using camera obs
+ if self.use_camera_obs:
+ # Create sensor information
+ sensors = []
+ names = []
+ for (cam_name, cam_w, cam_h, cam_d, cam_segs) in zip(
+ self.camera_names,
+ self.camera_widths,
+ self.camera_heights,
+ self.camera_depths,
+ self.camera_segmentations,
+ ):
+
+ # Add cameras associated to our arrays
+ cam_sensors, cam_sensor_names = self._create_camera_sensors(
+ cam_name, cam_w=cam_w, cam_h=cam_h, cam_d=cam_d, cam_segs=cam_segs, modality="image"
+ )
+ sensors += cam_sensors
+ names += cam_sensor_names
+
+ # If any the camera segmentations are not None, then we shrink all the sites as a hacky way to
+ # prevent them from being rendered in the segmentation mask
+ if not all(seg is None for seg in self.camera_segmentations):
+ self.sim.model.site_size[:, :] = 1.0e-8
+
+ # Create observables for these cameras
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _create_camera_sensors(self, cam_name, cam_w, cam_h, cam_d, cam_segs, modality="image"):
+ """
+ Helper function to create sensors for a given camera. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+ Args:
+ cam_name (str): Name of camera to create sensors for
+ cam_w (int): Width of camera
+ cam_h (int): Height of camera
+ cam_d (bool): Whether to create a depth sensor as well
+ cam_segs (None or list): Type of segmentation(s) to use, where each entry can be the following:
+ `None`: no segmentation sensor used
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+
+ modality (str): Modality to assign to all sensors
+ Returns:
+ 2-tuple:
+ sensors (list): Array of sensors for the given camera
+ names (list): array of corresponding observable names
+ """
+ # Make sure we get correct convention
+ convention = IMAGE_CONVENTION_MAPPING[macros.IMAGE_CONVENTION]
+
+ # Create sensor information
+ sensors = []
+ names = []
+
+ # Add camera observables to the dict
+ rgb_sensor_name = f"{cam_name}_image"
+ depth_sensor_name = f"{cam_name}_depth"
+ segmentation_sensor_name = f"{cam_name}_segmentation"
+
+ @sensor(modality=modality)
+ def camera_rgb(obs_cache):
+ img = self.sim.render(
+ camera_name=cam_name,
+ width=cam_w,
+ height=cam_h,
+ depth=cam_d,
+ )
+ if cam_d:
+ rgb, depth = img
+ obs_cache[depth_sensor_name] = np.expand_dims(depth[::convention], axis=-1)
+ return rgb[::convention]
+ else:
+ return img[::convention]
+
+ sensors.append(camera_rgb)
+ names.append(rgb_sensor_name)
+
+ if cam_d:
+
+ @sensor(modality=modality)
+ def camera_depth(obs_cache):
+ return obs_cache[depth_sensor_name] if depth_sensor_name in obs_cache else np.zeros((cam_h, cam_w, 1))
+
+ sensors.append(camera_depth)
+ names.append(depth_sensor_name)
+
+ if cam_segs is not None:
+ # Define mapping we'll use for segmentation
+ for cam_s in cam_segs:
+ seg_sensor, seg_sensor_name = self._create_segementation_sensor(
+ cam_name=cam_name,
+ cam_w=cam_w,
+ cam_h=cam_h,
+ cam_s=cam_s,
+ seg_name_root=segmentation_sensor_name,
+ modality=modality,
+ )
+
+ sensors.append(seg_sensor)
+ names.append(seg_sensor_name)
+
+ return sensors, names
+
+ def _create_segementation_sensor(self, cam_name, cam_w, cam_h, cam_s, seg_name_root, modality="image"):
+ """
+ Helper function to create sensors for a given camera. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+
+ Args:
+ cam_name (str): Name of camera to create sensors for
+ cam_w (int): Width of camera
+ cam_h (int): Height of camera
+ cam_s (None or list): Type of segmentation to use, should be the following:
+ `'instance'`: segmentation at the class-instance level
+ `'class'`: segmentation at the class level
+ `'element'`: segmentation at the per-geom level
+ seg_name_root (str): Sensor name root to assign to this sensor
+
+ modality (str): Modality to assign to all sensors
+
+ Returns:
+ 2-tuple:
+ camera_segmentation (function): Generated sensor function for this segmentation sensor
+ name (str): Corresponding sensor name
+ """
+ # Make sure we get correct convention
+ convention = IMAGE_CONVENTION_MAPPING[macros.IMAGE_CONVENTION]
+
+ if cam_s == "instance":
+ name2id = {inst: i for i, inst in enumerate(list(self.model.instances_to_ids.keys()))}
+ mapping = {idn: name2id[inst] for idn, inst in self.model.geom_ids_to_instances.items()}
+ elif cam_s == "class":
+ name2id = {cls: i for i, cls in enumerate(list(self.model.classes_to_ids.keys()))}
+ mapping = {idn: name2id[cls] for idn, cls in self.model.geom_ids_to_classes.items()}
+ else: # element
+ # No additional mapping needed
+ mapping = None
+
+ @sensor(modality=modality)
+ def camera_segmentation(obs_cache):
+ seg = self.sim.render(
+ camera_name=cam_name,
+ width=cam_w,
+ height=cam_h,
+ depth=False,
+ segmentation=True,
+ )
+ seg = np.expand_dims(seg[::convention, :, 1], axis=-1)
+ # Map raw IDs to grouped IDs if we're using instance or class-level segmentation
+ if mapping is not None:
+ seg = (
+ np.fromiter(map(lambda x: mapping.get(x, -1), seg.flatten()), dtype=np.int32).reshape(
+ cam_h, cam_w, 1
+ )
+ + 1
+ )
+ return seg
+
+ name = f"{seg_name_root}_{cam_s}"
+
+ return camera_segmentation, name
+
+ def _reset_internal(self):
+ """
+ Resets simulation internal configurations.
+ """
+ # Run superclass reset functionality
+ super()._reset_internal()
+
+ # Reset controllers
+ reset_controllers()
+
+ # Reset action dim
+ self._action_dim = 0
+
+ # Reset robot and update action space dimension along the way
+ for robot in self.robots:
+ robot.reset(deterministic=self.deterministic_reset)
+ self._action_dim += robot.action_dim
+
+ # Update cameras if appropriate
+ if self.use_camera_obs:
+ temp_names = []
+ for cam_name in self.camera_names:
+ if "all-" in cam_name:
+ # We need to add all robot-specific camera names that include the key after the tag "all-"
+ start_idx = len(temp_names) - 1
+ key = cam_name.replace("all-", "")
+ for robot in self.robots:
+ for robot_cam_name in robot.robot_model.cameras:
+ if key in robot_cam_name:
+ temp_names.append(robot_cam_name)
+ # We also need to broadcast the corresponding values from each camera dimensions as well
+ end_idx = len(temp_names) - 1
+ self.camera_widths = (
+ self.camera_widths[:start_idx]
+ + [self.camera_widths[start_idx]] * (end_idx - start_idx)
+ + self.camera_widths[(start_idx + 1) :]
+ )
+ self.camera_heights = (
+ self.camera_heights[:start_idx]
+ + [self.camera_heights[start_idx]] * (end_idx - start_idx)
+ + self.camera_heights[(start_idx + 1) :]
+ )
+ self.camera_depths = (
+ self.camera_depths[:start_idx]
+ + [self.camera_depths[start_idx]] * (end_idx - start_idx)
+ + self.camera_depths[(start_idx + 1) :]
+ )
+ else:
+ # We simply add this camera to the temp_names
+ temp_names.append(cam_name)
+ # Lastly, replace camera names with the updated ones
+ self.camera_names = temp_names
+
+ def _pre_action(self, action, policy_step=False):
+ """
+ Overrides the superclass method to control the robot(s) within this enviornment using their respective
+ controllers using the passed actions and gripper control.
+
+ Args:
+ action (np.array): The control to apply to the robot(s). Note that this should be a flat 1D array that
+ encompasses all actions to be distributed to each robot if there are multiple. For each section of the
+ action space assigned to a single robot, the first @self.robots[i].controller.control_dim dimensions
+ should be the desired controller actions and if the robot has a gripper, the next
+ @self.robots[i].gripper.dof dimensions should be actuation controls for the gripper.
+ policy_step (bool): Whether a new policy step (action) is being taken
+
+ Raises:
+ AssertionError: [Invalid action dimension]
+ """
+ # Verify that the action is the correct dimension
+ assert len(action) == self.action_dim, "environment got invalid action dimension -- expected {}, got {}".format(
+ self.action_dim, len(action)
+ )
+
+ # Update robot joints based on controller actions
+ cutoff = 0
+ for idx, robot in enumerate(self.robots):
+ robot_action = action[cutoff : cutoff + robot.action_dim]
+ robot.control(robot_action, policy_step=policy_step)
+ cutoff += robot.action_dim
+
+ def _load_robots(self):
+ """
+ Instantiates robots and stores them within the self.robots attribute
+ """
+ # Loop through robots and instantiate Robot object for each
+ for idx, (name, config) in enumerate(zip(self.robot_names, self.robot_configs)):
+ # Create the robot instance
+ self.robots[idx] = ROBOT_CLASS_MAPPING[name](robot_type=name, idn=idx, **config)
+ # Now, load the robot models
+ self.robots[idx].load_model()
+
+ def reward(self, action):
+ """
+ Runs superclass method by default
+ """
+ return super().reward(action)
+
+ def _check_success(self):
+ """
+ Runs superclass method by default
+ """
+ return super()._check_success()
+
+ def _check_robot_configuration(self, robots):
+ """
+ Sanity check to make sure inputted robots and the corresponding requested task/configuration combo is legal.
+ Should be implemented in every specific task module
+
+ Args:
+ robots (str or list of str): Inputted requested robots at the task-level environment
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/macros.py b/phantom/submodules/phantom-robosuite/robosuite/macros.py
new file mode 100644
index 0000000000000000000000000000000000000000..918c51590448a05557479f3552834f3079ff35c1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/macros.py
@@ -0,0 +1,55 @@
+"""
+Macro settings that can be imported and toggled. Internally, specific parts of the codebase rely on these settings
+for determining core functionality.
+
+To make sure global reference is maintained, should import these settings as:
+
+`import robosuite.macros as macros`
+"""
+
+# Global Mujoco Simulation Parameters
+SIMULATION_TIMESTEP = 0.002 # Internal simulation timestep (in seconds)
+
+# Instance Randomization
+# Used if we want to randomize geom groups uniformly per instance -- e.g.: entire robot arm, vs. per-joint geom
+# This should get set to True in your script BEFORE an environment is created or the DR wrapper is used
+USING_INSTANCE_RANDOMIZATION = False
+
+# Numba settings
+# TODO: Numba causes BSOD for NutAssembly task when rendering offscreen (deterministically!)
+ENABLE_NUMBA = True
+CACHE_NUMBA = True
+
+# Image Convention
+# Robosuite (Mujoco)-rendered images are based on the OpenGL coordinate frame convention, whereas many downstream
+# applications assume an OpenCV coordinate frame convention. For consistency, you can set the image convention
+# here; this will assure that any rendered frames will match the associated convention.
+# See the figure at the bottom of https://amytabb.com/ts/2019_06_28/ for an informative overview.
+IMAGE_CONVENTION = "opencv" # Options are {"opengl", "opencv"}
+
+# Image concatenation
+# In general, observations are concatenated together by modality. However, image observations are expensive memory-wise,
+# so we skip concatenating all images together by default, unless this flag is set to True
+CONCATENATE_IMAGES = False
+
+MUJOCO_GPU_RENDERING = True
+
+# Spacemouse settings. Used by SpaceMouse class in robosuite/devices/spacemouse.py
+SPACEMOUSE_VENDOR_ID = 9583
+SPACEMOUSE_PRODUCT_ID = 50734
+
+# If LOGGING LEVEL is set to None, the logger will be turned off
+CONSOLE_LOGGING_LEVEL = "WARN"
+# File logging is written to /tmp/robosuite_{time}_{pid}.log by default
+FILE_LOGGING_LEVEL = None
+
+# Override with macros from macros_private.py file, if it exists
+try:
+ from robosuite.macros_private import *
+except ImportError:
+ import robosuite
+ from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER
+
+ ROBOSUITE_DEFAULT_LOGGER.warn("No private macro file found!")
+ ROBOSUITE_DEFAULT_LOGGER.warn("It is recommended to use a private macro file")
+ ROBOSUITE_DEFAULT_LOGGER.warn("To setup, run: python {}/scripts/setup_macros.py".format(robosuite.__path__[0]))
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73dc77f94297a4ff36e4db5e3e360f691189f6f8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/__init__.py
@@ -0,0 +1,4 @@
+import os
+from .world import MujocoWorldBase
+
+assets_root = os.path.join(os.path.dirname(__file__), "assets")
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab448ace61fb1cb6d2947202d009c7a5601e4096
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/__init__.py
@@ -0,0 +1,9 @@
+from .arena import Arena
+from .table_arena import TableArena
+from .table_arena2 import TableArena2
+from .phantom_table_arena import PhantomTableArena
+from .multi_table_arena import MultiTableArena
+from .pegs_arena import PegsArena
+from .bins_arena import BinsArena
+from .empty_arena import EmptyArena
+from .wipe_arena import WipeArena
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..d274f9f5e56e7b5f2fe5ef6f9b6bab4312530e1a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/arena.py
@@ -0,0 +1,81 @@
+import numpy as np
+
+from robosuite.models.base import MujocoXML
+from robosuite.utils.mjcf_utils import (
+ ENVIRONMENT_COLLISION_COLOR,
+ array_to_string,
+ find_elements,
+ new_body,
+ new_element,
+ new_geom,
+ new_joint,
+ recolor_collision_geoms,
+ string_to_array,
+)
+
+
+class Arena(MujocoXML):
+ """Base arena class."""
+
+ def __init__(self, fname):
+ super().__init__(fname)
+ # Get references to floor and bottom
+ self.bottom_pos = np.zeros(3)
+ self.floor = self.worldbody.find("./geom[@name='floor']")
+
+ # Run any necessary post-processing on the model
+ self._postprocess_arena()
+
+ # Recolor all geoms
+ recolor_collision_geoms(
+ root=self.worldbody,
+ rgba=ENVIRONMENT_COLLISION_COLOR,
+ exclude=lambda e: True if e.get("name", None) == "floor" else False,
+ )
+
+ def set_origin(self, offset):
+ """
+ Applies a constant offset to all objects.
+
+ Args:
+ offset (3-tuple): (x,y,z) offset to apply to all nodes in this XML
+ """
+ offset = np.array(offset)
+ for node in self.worldbody.findall("./*[@pos]"):
+ cur_pos = string_to_array(node.get("pos"))
+ new_pos = cur_pos + offset
+ node.set("pos", array_to_string(new_pos))
+
+ def set_camera(self, camera_name, pos, quat, camera_attribs=None):
+ """
+ Sets a camera with @camera_name. If the camera already exists, then this overwrites its pos and quat values.
+
+ Args:
+ camera_name (str): Camera name to search for / create
+ pos (3-array): (x,y,z) coordinates of camera in world frame
+ quat (4-array): (w,x,y,z) quaternion of camera in world frame
+ camera_attribs (dict): If specified, should be additional keyword-mapped attributes for this camera.
+ See http://www.mujoco.org/book/XMLreference.html#camera for exact attribute specifications.
+ """
+ # Determine if camera already exists
+ camera = find_elements(root=self.worldbody, tags="camera", attribs={"name": camera_name}, return_first=True)
+
+ # Compose attributes
+ if camera_attribs is None:
+ camera_attribs = {}
+ camera_attribs["pos"] = array_to_string(pos)
+ camera_attribs["quat"] = array_to_string(quat)
+
+ if camera is None:
+ # If camera doesn't exist, then add a new camera with the specified attributes
+ self.worldbody.append(new_element(tag="camera", name=camera_name, **camera_attribs))
+ else:
+ # Otherwise, we edit all specified attributes in that camera
+ for attrib, value in camera_attribs.items():
+ camera.set(attrib, value)
+
+ def _postprocess_arena(self):
+ """
+ Runs any necessary post-processing on the imported Arena model
+ """
+ pass
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/bins_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/bins_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..b50b40bcdd298a3398607fb284bf0a48962343f7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/bins_arena.py
@@ -0,0 +1,34 @@
+import numpy as np
+
+from robosuite.models.arenas import Arena
+from robosuite.utils.mjcf_utils import array_to_string, xml_path_completion
+
+
+class BinsArena(Arena):
+ """
+ Workspace that contains two bins placed side by side.
+
+ Args:
+ bin1_pos (3-tuple): (x,y,z) position to place bin1
+ table_full_size (3-tuple): (L,W,H) full dimensions of the table
+ table_friction (3-tuple): (sliding, torsional, rolling) friction parameters of the table
+ """
+
+ def __init__(
+ self, bin1_pos=(0.1, -0.5, 0.8), table_full_size=(0.39, 0.49, 0.82), table_friction=(1, 0.005, 0.0001)
+ ):
+ super().__init__(xml_path_completion("arenas/bins_arena.xml"))
+
+ self.table_full_size = np.array(table_full_size)
+ self.table_half_size = self.table_full_size / 2
+ self.table_friction = table_friction
+
+ self.bin1_body = self.worldbody.find("./body[@name='bin1']")
+ self.bin2_body = self.worldbody.find("./body[@name='bin2']")
+ self.table_top_abs = np.array(bin1_pos)
+
+ self.configure_location()
+
+ def configure_location(self):
+ """Configures correct locations for this arena"""
+ self.floor.set("pos", array_to_string(self.bottom_pos))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/empty_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/empty_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10da831b24b0c1870ffb640327cdf543efd02ab
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/empty_arena.py
@@ -0,0 +1,9 @@
+from robosuite.models.arenas import Arena
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class EmptyArena(Arena):
+ """Empty workspace."""
+
+ def __init__(self):
+ super().__init__(xml_path_completion("arenas/empty_arena.xml"))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/multi_table_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/multi_table_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..f62ef357e2c45a483ac2643cdea084a7b9380722
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/multi_table_arena.py
@@ -0,0 +1,149 @@
+from collections.abc import Iterable
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.arenas import Arena
+from robosuite.utils.mjcf_utils import (
+ array_to_string,
+ new_body,
+ new_geom,
+ new_site,
+ string_to_array,
+ xml_path_completion,
+)
+
+
+class MultiTableArena(Arena):
+ """
+ Workspace that contains multiple tables.
+ Args:
+ table_offsets (list of 3-array): (x,y,z) offset from center of arena when placing each table.
+ Note that the number of tables is inferred from the length of this list
+ Note that the z value sets the upper limit of the table
+ table_rots (float or list of float): z-rotation to apply to each table. If only a
+ single value is given, it will be broadcasted according to the total number of tables
+ table_full_sizes (3-array or list of 3-array): (L,W,H) full dimensions of each table. If only a
+ single value is given, it will be broadcasted according to the total number of tables
+ table_frictions (3-array or list of 3-array): (sliding, torsional, rolling) friction parameters of each table.
+ has_legs (bool or list of bool): whether each table has legs or not. If only a
+ single value is given, it will be broadcasted according to the total number of tables
+ xml (str): xml file to load arena
+ """
+
+ def __init__(
+ self,
+ table_offsets,
+ table_rots=0,
+ table_full_sizes=(0.8, 0.8, 0.05),
+ table_frictions=(1, 0.005, 0.0001),
+ has_legs=True,
+ xml="arenas/multi_table_arena.xml",
+ ):
+ # Set internal vars
+ self.table_offsets = np.array(table_offsets)
+ self.n_tables = self.table_offsets.shape[0]
+ self.table_rots = (
+ np.array(table_rots) if isinstance(table_rots, Iterable) else np.ones(self.n_tables) * table_rots
+ )
+ self.table_full_sizes = np.array(table_full_sizes)
+ if len(self.table_full_sizes.shape) == 1:
+ self.table_full_sizes = np.stack([self.table_full_sizes] * self.n_tables, axis=0)
+ self.table_half_sizes = self.table_full_sizes / 2
+ self.table_frictions = np.array(table_frictions)
+ if len(self.table_frictions.shape) == 1:
+ self.table_frictions = np.stack([self.table_frictions] * self.n_tables, axis=0)
+ self.center_pos = np.array(self.table_offsets)
+ self.center_pos[:, 2] -= self.table_half_sizes[:, 2]
+ self.has_legs = has_legs if isinstance(has_legs, Iterable) else [has_legs] * self.n_tables
+
+ # Run super init
+ super().__init__(xml_path_completion(xml))
+
+ # Configure any relevant locations
+ self.configure_location()
+
+ def _add_table(self, name, offset, rot, half_size, friction, has_legs):
+ """
+ Procedurally generates a table and adds it to the XML
+ """
+ # Create body for this table, and add it to worldbody
+ table_body = new_body(name=name, pos=offset - np.array([0, 0, half_size[2]]))
+ self.worldbody.append(table_body)
+
+ # Create core attributes for table geoms
+ table_attribs = {
+ "pos": (0, 0, 0),
+ "quat": T.convert_quat(T.axisangle2quat([0, 0, rot]), to="wxyz"),
+ "size": half_size,
+ "type": "box",
+ }
+
+ # Create collision and visual bodies, and add them to the table body
+ col_geom = new_geom(name=f"{name}_collision", group=0, friction=friction, **table_attribs)
+ vis_geom = new_geom(
+ name=f"{name}_visual", group=1, conaffinity=0, contype=0, material="table_ceramic", **table_attribs
+ )
+ table_body.append(col_geom)
+ table_body.append(vis_geom)
+
+ # Add tabletop site to table
+ top_site = new_site(name=f"{name}_top", pos=(0, 0, half_size[2]), size=(0.001, 0.001, 0.001), rgba=(0, 0, 0, 0))
+ table_body.append(top_site)
+
+ # Add legs if requested
+ if has_legs:
+ delta_x = [0.1, -0.1, -0.1, 0.1]
+ delta_y = [0.1, 0.1, -0.1, -0.1]
+ for i, (dx, dy) in enumerate(zip(delta_x, delta_y)):
+ # If x-length of table is less than a certain length, place leg in the middle between ends
+ # Otherwise we place it near the edge
+ x = 0
+ if half_size[0] > abs(dx * 2.0):
+ x += np.sign(dx) * half_size[0] - dx
+ # Repeat the same process for y
+ y = 0
+ if half_size[1] > abs(dy * 2.0):
+ y += np.sign(dy) * half_size[1] - dy
+ # Rotate x and y values according to requested rotation
+ c, s = np.cos(rot), np.sin(rot)
+ rot_xy = np.array([[c, -s], [s, c]]) @ np.array([x, y])
+ # Add in offsets
+ x = rot_xy[0]
+ y = rot_xy[1]
+ # Get z value
+ z = (offset[2] - half_size[2]) / 2.0
+ # Create visual geom and add it to table body
+ leg_geom = new_geom(
+ name=f"{name}_leg{i}_visual",
+ pos=(x, y, -z),
+ type="cylinder",
+ size=(0.025, z),
+ group=1,
+ conaffinity=0,
+ contype=0,
+ material="table_legs_metal",
+ )
+ table_body.append(leg_geom)
+
+ def configure_location(self):
+ """Configures correct locations for this arena"""
+ # Set floor correctly
+ self.floor.set("pos", array_to_string(self.bottom_pos))
+
+ def _postprocess_arena(self):
+ """
+ Runs any necessary post-processing on the imported Arena model
+ """
+ # Create tables
+ for i, (offset, rot, half_size, friction, legs) in enumerate(
+ zip(self.table_offsets, self.table_rots, self.table_half_sizes, self.table_frictions, self.has_legs)
+ ):
+ self._add_table(
+ name=f"table{i}",
+ offset=offset,
+ rot=rot,
+ half_size=half_size,
+ friction=friction,
+ has_legs=legs,
+ )
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/pegs_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/pegs_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..492b99d331038c2a791d111e43d797030ec1cac9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/pegs_arena.py
@@ -0,0 +1,30 @@
+from robosuite.models.arenas import TableArena
+
+
+class PegsArena(TableArena):
+ """
+ Workspace that contains a tabletop with two fixed pegs.
+
+ Args:
+ table_full_size (3-tuple): (L,W,H) full dimensions of the table
+ table_friction (3-tuple): (sliding, torsional, rolling) friction parameters of the table
+ table_offset (3-tuple): (x,y,z) offset from center of arena when placing table.
+ Note that the z value sets the upper limit of the table
+ """
+
+ def __init__(
+ self,
+ table_full_size=(0.45, 0.69, 0.05),
+ table_friction=(1, 0.005, 0.0001),
+ table_offset=(0, 0, 0),
+ ):
+ super().__init__(
+ table_full_size=table_full_size,
+ table_friction=table_friction,
+ table_offset=table_offset,
+ xml="arenas/pegs_arena.xml",
+ )
+
+ # Get references to peg bodies
+ self.peg1_body = self.worldbody.find("./body[@name='peg1']")
+ self.peg2_body = self.worldbody.find("./body[@name='peg2']")
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/phantom_table_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/phantom_table_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..abca7d38409eb8511489beb1e59e8da857e28be2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/phantom_table_arena.py
@@ -0,0 +1,97 @@
+import numpy as np
+
+from robosuite.models.arenas import Arena
+from robosuite.utils.mjcf_utils import array_to_string, string_to_array, xml_path_completion
+
+
+class PhantomTableArena(Arena):
+ """
+ Workspace that contains an empty table.
+
+
+ Args:
+ table_full_size (3-tuple): (L,W,H) full dimensions of the table
+ table_friction (3-tuple): (sliding, torsional, rolling) friction parameters of the table
+ table_offset (3-tuple): (x,y,z) offset from center of arena when placing table.
+ Note that the z value sets the upper limit of the table
+ has_legs (bool): whether the table has legs or not
+ xml (str): xml file to load arena
+ """
+
+ def __init__(
+ self,
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1, 0.005, 0.0001),
+ table_offset=(0, 0, 0.8),
+ has_legs=True,
+ xml="arenas/phantom_table_arena.xml",
+ ):
+ super().__init__(xml_path_completion(xml))
+
+ self.table_full_size = np.array(table_full_size)
+ self.table_half_size = self.table_full_size / 2
+ self.table_friction = table_friction
+ self.table_offset = table_offset
+ self.center_pos = self.bottom_pos + np.array([0, 0, -self.table_half_size[2]]) + self.table_offset
+
+ self.table_body = self.worldbody.find("./body[@name='table']")
+ self.table_collision = self.table_body.find("./geom[@name='table_collision']")
+ self.table_visual = self.table_body.find("./geom[@name='table_visual']")
+ self.table_top = self.table_body.find("./site[@name='table_top']")
+
+ self.has_legs = has_legs
+ self.table_legs_visual = [
+ self.table_body.find("./geom[@name='table_leg1_visual']"),
+ self.table_body.find("./geom[@name='table_leg2_visual']"),
+ self.table_body.find("./geom[@name='table_leg3_visual']"),
+ self.table_body.find("./geom[@name='table_leg4_visual']"),
+ ]
+
+ self.configure_location()
+
+ def configure_location(self):
+ """Configures correct locations for this arena"""
+ self.floor.set("pos", array_to_string(self.bottom_pos))
+
+ self.table_body.set("pos", array_to_string(self.center_pos))
+ self.table_collision.set("size", array_to_string(self.table_half_size))
+ self.table_collision.set("friction", array_to_string(self.table_friction))
+ self.table_visual.set("size", array_to_string(self.table_half_size))
+
+ self.table_top.set("pos", array_to_string(np.array([0, 0, self.table_half_size[2]])))
+
+ # If we're not using legs, set their size to 0
+ if not self.has_legs:
+ for leg in self.table_legs_visual:
+ leg.set("rgba", array_to_string([1, 0, 0, 0]))
+ leg.set("size", array_to_string([0.0001, 0.0001]))
+ else:
+ # Otherwise, set leg locations appropriately
+ delta_x = [0.1, -0.1, -0.1, 0.1]
+ delta_y = [0.1, 0.1, -0.1, -0.1]
+ for leg, dx, dy in zip(self.table_legs_visual, delta_x, delta_y):
+ # If x-length of table is less than a certain length, place leg in the middle between ends
+ # Otherwise we place it near the edge
+ x = 0
+ if self.table_half_size[0] > abs(dx * 2.0):
+ x += np.sign(dx) * self.table_half_size[0] - dx
+ # Repeat the same process for y
+ y = 0
+ if self.table_half_size[1] > abs(dy * 2.0):
+ y += np.sign(dy) * self.table_half_size[1] - dy
+ # Get z value
+ z = (self.table_offset[2] - self.table_half_size[2]) / 2.0
+ # Set leg position
+ leg.set("pos", array_to_string([x, y, -z]))
+ # Set leg size
+ leg.set("size", array_to_string([0.025, z]))
+
+ @property
+ def table_top_abs(self):
+ """
+ Grabs the absolute position of table top
+
+ Returns:
+ np.array: (x,y,z) table position
+ """
+ return string_to_array(self.floor.get("pos")) + self.table_offset
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/table_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/table_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a14c11c95d612ec6fd34eea62623d28a2460220
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/table_arena.py
@@ -0,0 +1,97 @@
+import numpy as np
+
+from robosuite.models.arenas import Arena
+from robosuite.utils.mjcf_utils import array_to_string, string_to_array, xml_path_completion
+
+
+class TableArena(Arena):
+ """
+ Workspace that contains an empty table.
+
+
+ Args:
+ table_full_size (3-tuple): (L,W,H) full dimensions of the table
+ table_friction (3-tuple): (sliding, torsional, rolling) friction parameters of the table
+ table_offset (3-tuple): (x,y,z) offset from center of arena when placing table.
+ Note that the z value sets the upper limit of the table
+ has_legs (bool): whether the table has legs or not
+ xml (str): xml file to load arena
+ """
+
+ def __init__(
+ self,
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1, 0.005, 0.0001),
+ table_offset=(0, 0, 0.8),
+ has_legs=True,
+ xml="arenas/table_arena.xml",
+ ):
+ super().__init__(xml_path_completion(xml))
+
+ self.table_full_size = np.array(table_full_size)
+ self.table_half_size = self.table_full_size / 2
+ self.table_friction = table_friction
+ self.table_offset = table_offset
+ self.center_pos = self.bottom_pos + np.array([0, 0, -self.table_half_size[2]]) + self.table_offset
+
+ self.table_body = self.worldbody.find("./body[@name='table']")
+ self.table_collision = self.table_body.find("./geom[@name='table_collision']")
+ self.table_visual = self.table_body.find("./geom[@name='table_visual']")
+ self.table_top = self.table_body.find("./site[@name='table_top']")
+
+ self.has_legs = has_legs
+ self.table_legs_visual = [
+ self.table_body.find("./geom[@name='table_leg1_visual']"),
+ self.table_body.find("./geom[@name='table_leg2_visual']"),
+ self.table_body.find("./geom[@name='table_leg3_visual']"),
+ self.table_body.find("./geom[@name='table_leg4_visual']"),
+ ]
+
+ self.configure_location()
+
+ def configure_location(self):
+ """Configures correct locations for this arena"""
+ self.floor.set("pos", array_to_string(self.bottom_pos))
+
+ self.table_body.set("pos", array_to_string(self.center_pos))
+ self.table_collision.set("size", array_to_string(self.table_half_size))
+ self.table_collision.set("friction", array_to_string(self.table_friction))
+ self.table_visual.set("size", array_to_string(self.table_half_size))
+
+ self.table_top.set("pos", array_to_string(np.array([0, 0, self.table_half_size[2]])))
+
+ # If we're not using legs, set their size to 0
+ if not self.has_legs:
+ for leg in self.table_legs_visual:
+ leg.set("rgba", array_to_string([1, 0, 0, 0]))
+ leg.set("size", array_to_string([0.0001, 0.0001]))
+ else:
+ # Otherwise, set leg locations appropriately
+ delta_x = [0.1, -0.1, -0.1, 0.1]
+ delta_y = [0.1, 0.1, -0.1, -0.1]
+ for leg, dx, dy in zip(self.table_legs_visual, delta_x, delta_y):
+ # If x-length of table is less than a certain length, place leg in the middle between ends
+ # Otherwise we place it near the edge
+ x = 0
+ if self.table_half_size[0] > abs(dx * 2.0):
+ x += np.sign(dx) * self.table_half_size[0] - dx
+ # Repeat the same process for y
+ y = 0
+ if self.table_half_size[1] > abs(dy * 2.0):
+ y += np.sign(dy) * self.table_half_size[1] - dy
+ # Get z value
+ z = (self.table_offset[2] - self.table_half_size[2]) / 2.0
+ # Set leg position
+ leg.set("pos", array_to_string([x, y, -z]))
+ # Set leg size
+ leg.set("size", array_to_string([0.025, z]))
+
+ @property
+ def table_top_abs(self):
+ """
+ Grabs the absolute position of table top
+
+ Returns:
+ np.array: (x,y,z) table position
+ """
+ return string_to_array(self.floor.get("pos")) + self.table_offset
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/table_arena2.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/table_arena2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5813888d0157926b0ee58cb94a8186bce174eb3d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/table_arena2.py
@@ -0,0 +1,98 @@
+
+import numpy as np
+
+from robosuite.models.arenas import Arena
+from robosuite.utils.mjcf_utils import array_to_string, string_to_array, xml_path_completion
+
+
+class TableArena2(Arena):
+ """
+ Workspace that contains an empty table.
+
+
+ Args:
+ table_full_size (3-tuple): (L,W,H) full dimensions of the table
+ table_friction (3-tuple): (sliding, torsional, rolling) friction parameters of the table
+ table_offset (3-tuple): (x,y,z) offset from center of arena when placing table.
+ Note that the z value sets the upper limit of the table
+ has_legs (bool): whether the table has legs or not
+ xml (str): xml file to load arena
+ """
+
+ def __init__(
+ self,
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(1, 0.005, 0.0001),
+ table_offset=(0, 0, 0.8),
+ has_legs=True,
+ xml="arenas/table_arena2.xml",
+ ):
+ super().__init__(xml_path_completion(xml))
+
+ self.table_full_size = np.array(table_full_size)
+ self.table_half_size = self.table_full_size / 2
+ self.table_friction = table_friction
+ self.table_offset = table_offset
+ self.center_pos = self.bottom_pos + np.array([0, 0, -self.table_half_size[2]]) + self.table_offset
+
+ self.table_body = self.worldbody.find("./body[@name='table']")
+ self.table_collision = self.table_body.find("./geom[@name='table_collision']")
+ self.table_visual = self.table_body.find("./geom[@name='table_visual']")
+ self.table_top = self.table_body.find("./site[@name='table_top']")
+
+ self.has_legs = has_legs
+ self.table_legs_visual = [
+ self.table_body.find("./geom[@name='table_leg1_visual']"),
+ self.table_body.find("./geom[@name='table_leg2_visual']"),
+ self.table_body.find("./geom[@name='table_leg3_visual']"),
+ self.table_body.find("./geom[@name='table_leg4_visual']"),
+ ]
+
+ self.configure_location()
+
+ def configure_location(self):
+ """Configures correct locations for this arena"""
+ self.floor.set("pos", array_to_string(self.bottom_pos))
+
+ self.table_body.set("pos", array_to_string(self.center_pos))
+ self.table_collision.set("size", array_to_string(self.table_half_size))
+ self.table_collision.set("friction", array_to_string(self.table_friction))
+ self.table_visual.set("size", array_to_string(self.table_half_size))
+
+ self.table_top.set("pos", array_to_string(np.array([0, 0, self.table_half_size[2]])))
+
+ # If we're not using legs, set their size to 0
+ if not self.has_legs:
+ for leg in self.table_legs_visual:
+ leg.set("rgba", array_to_string([1, 0, 0, 0]))
+ leg.set("size", array_to_string([0.0001, 0.0001]))
+ else:
+ # Otherwise, set leg locations appropriately
+ delta_x = [0.1, -0.1, -0.1, 0.1]
+ delta_y = [0.1, 0.1, -0.1, -0.1]
+ for leg, dx, dy in zip(self.table_legs_visual, delta_x, delta_y):
+ # If x-length of table is less than a certain length, place leg in the middle between ends
+ # Otherwise we place it near the edge
+ x = 0
+ if self.table_half_size[0] > abs(dx * 2.0):
+ x += np.sign(dx) * self.table_half_size[0] - dx
+ # Repeat the same process for y
+ y = 0
+ if self.table_half_size[1] > abs(dy * 2.0):
+ y += np.sign(dy) * self.table_half_size[1] - dy
+ # Get z value
+ z = (self.table_offset[2] - self.table_half_size[2]) / 2.0
+ # Set leg position
+ leg.set("pos", array_to_string([x, y, -z]))
+ # Set leg size
+ leg.set("size", array_to_string([0.025, z]))
+
+ @property
+ def table_top_abs(self):
+ """
+ Grabs the absolute position of table top
+
+ Returns:
+ np.array: (x,y,z) table position
+ """
+ return string_to_array(self.floor.get("pos")) + self.table_offset
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/arenas/wipe_arena.py b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/wipe_arena.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1db1cad4eb0501dda1db869eb86871dc74981a3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/arenas/wipe_arena.py
@@ -0,0 +1,186 @@
+import numpy as np
+
+from robosuite.models.arenas import TableArena
+from robosuite.models.objects import CylinderObject
+from robosuite.utils.mjcf_utils import CustomMaterial, find_elements
+
+
+class WipeArena(TableArena):
+ """
+ Workspace that contains an empty table with visual markers on its surface.
+
+ Args:
+ table_full_size (3-tuple): (L,W,H) full dimensions of the table
+ table_friction (3-tuple): (sliding, torsional, rolling) friction parameters of the table
+ table_offset (3-tuple): (x,y,z) offset from center of arena when placing table.
+ Note that the z value sets the upper limit of the table
+ coverage_factor (float): Fraction of table that will be sampled for dirt placement
+ num_markers (int): Number of dirt (peg) particles to generate in a path on the table
+ table_friction_std (float): Standard deviation to sample for the peg friction
+ line_width (float): Diameter of dirt path trace
+ two_clusters (bool): If set, will generate two separate dirt paths with half the number of sensors in each
+ """
+
+ def __init__(
+ self,
+ table_full_size=(0.8, 0.8, 0.05),
+ table_friction=(0.01, 0.005, 0.0001),
+ table_offset=(0, 0, 0.8),
+ coverage_factor=0.9,
+ num_markers=10,
+ table_friction_std=0,
+ line_width=0.02,
+ two_clusters=False,
+ ):
+ # Tactile table-specific features
+ self.table_friction_std = table_friction_std
+ self.line_width = line_width
+ self.markers = []
+ self.coverage_factor = coverage_factor
+ self.num_markers = num_markers
+ self.two_clusters = two_clusters
+
+ # Attribute to hold current direction of sampled dirt path
+ self.direction = None
+
+ # run superclass init
+ super().__init__(
+ table_full_size=table_full_size,
+ table_friction=table_friction,
+ table_offset=table_offset,
+ )
+
+ def configure_location(self):
+ """Configures correct locations for this arena"""
+ # Run superclass first
+ super().configure_location()
+
+ # Define start position for drawing the line
+ pos = self.sample_start_pos()
+
+ # Define dirt material for markers
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "1 1",
+ "specular": "0.0",
+ "shininess": "0.0",
+ }
+ dirt = CustomMaterial(
+ texture="Dirt",
+ tex_name="dirt",
+ mat_name="dirt_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ shared=True,
+ )
+
+ # Define line(s) drawn on table
+ for i in range(self.num_markers):
+ # If we're using two clusters, we resample the starting position and direction at the halfway point
+ if self.two_clusters and i == int(np.floor(self.num_markers / 2)):
+ pos = self.sample_start_pos()
+ marker_name = f"contact{i}"
+ marker = CylinderObject(
+ name=marker_name,
+ size=[self.line_width / 2, 0.001],
+ rgba=[1, 1, 1, 1],
+ material=dirt,
+ obj_type="visual",
+ joints=None,
+ )
+ # Manually add this object to the arena xml
+ self.merge_assets(marker)
+ table = find_elements(root=self.worldbody, tags="body", attribs={"name": "table"}, return_first=True)
+ table.append(marker.get_obj())
+
+ # Add this marker to our saved list of all markers
+ self.markers.append(marker)
+
+ # Add to the current dirt path
+ pos = self.sample_path_pos(pos)
+
+ def reset_arena(self, sim):
+ """
+ Reset the visual marker locations in the environment. Requires @sim (MjSim) reference to be passed in so that
+ the Mujoco sim can be directly modified
+
+ Args:
+ sim (MjSim): Simulation instance containing this arena and visual markers
+ """
+ # Sample new initial position and direction for generated marker paths
+ pos = self.sample_start_pos()
+
+ # Loop through all visual markers
+ for i, marker in enumerate(self.markers):
+ # If we're using two clusters, we resample the starting position and direction at the halfway point
+ if self.two_clusters and i == int(np.floor(self.num_markers / 2)):
+ pos = self.sample_start_pos()
+ # Get IDs to the body, geom, and site of each marker
+ body_id = sim.model.body_name2id(marker.root_body)
+ geom_id = sim.model.geom_name2id(marker.visual_geoms[0])
+ site_id = sim.model.site_name2id(marker.sites[0])
+ # Determine new position for this marker
+ position = np.array([pos[0], pos[1], self.table_half_size[2]])
+ # Set the current marker (body) to this new position
+ sim.model.body_pos[body_id] = position
+ # Reset the marker visualization -- setting geom rgba alpha value to 1
+ sim.model.geom_rgba[geom_id][3] = 1
+ # Hide the default visualization site
+ sim.model.site_rgba[site_id][3] = 0
+ # Sample next values in local marker trajectory
+ pos = self.sample_path_pos(pos)
+
+ def sample_start_pos(self):
+ """
+ Helper function to return sampled start position of a new dirt (peg) location
+
+ Returns:
+ np.array: the (x,y) value of the newly sampled dirt starting location
+ """
+ # First define the random direction that we will start at
+ self.direction = np.random.uniform(-np.pi, np.pi)
+
+ return np.array(
+ (
+ np.random.uniform(
+ -self.table_half_size[0] * self.coverage_factor + self.line_width / 2,
+ self.table_half_size[0] * self.coverage_factor - self.line_width / 2,
+ ),
+ np.random.uniform(
+ -self.table_half_size[1] * self.coverage_factor + self.line_width / 2,
+ self.table_half_size[1] * self.coverage_factor - self.line_width / 2,
+ ),
+ )
+ )
+
+ def sample_path_pos(self, pos):
+ """
+ Helper function to add a sampled dirt (peg) position to a pre-existing dirt path, whose most
+ recent dirt position is defined by @pos
+
+ Args:
+ pos (np.array): (x,y) value of most recent dirt position
+
+ Returns:
+ np.array: the (x,y) value of the newly sampled dirt position to add to the current dirt path
+ """
+ # Random chance to alter the current dirt direction
+ if np.random.uniform(0, 1) > 0.7:
+ self.direction += np.random.normal(0, 0.5)
+
+ posnew0 = pos[0] + 0.005 * np.sin(self.direction)
+ posnew1 = pos[1] + 0.005 * np.cos(self.direction)
+
+ # We keep resampling until we get a valid new position that's on the table
+ while (
+ abs(posnew0) >= self.table_half_size[0] * self.coverage_factor - self.line_width / 2
+ or abs(posnew1) >= self.table_half_size[1] * self.coverage_factor - self.line_width / 2
+ ):
+ self.direction += np.random.normal(0, 0.5)
+ posnew0 = pos[0] + 0.005 * np.sin(self.direction)
+ posnew1 = pos[1] + 0.005 * np.cos(self.direction)
+
+ # Return this newly sampled position
+ return np.array((posnew0, posnew1))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/bins_arena.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/bins_arena.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5e93098acbb30bc267c5e8c6d93f201a0414abde
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/bins_arena.xml
@@ -0,0 +1,74 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/empty_arena.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/empty_arena.xml
new file mode 100644
index 0000000000000000000000000000000000000000..cab8bc8198fb4df90cd9e3a3b9868751857ff4c4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/empty_arena.xml
@@ -0,0 +1,35 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/multi_table_arena.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/multi_table_arena.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9c5cd3a54e7d5b6912fd179ae441296db07a1ddf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/multi_table_arena.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/pegs_arena.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/pegs_arena.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a3a3211860cb03550617b68d2a43e2fdfcd469e8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/pegs_arena.xml
@@ -0,0 +1,63 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/phantom_table_arena.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/phantom_table_arena.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c99b5829c38732853393f16d40eecf03460e9d35
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/phantom_table_arena.xml
@@ -0,0 +1,60 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/table_arena.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/table_arena.xml
new file mode 100644
index 0000000000000000000000000000000000000000..1c77448f67c23fcac3980a220149ef41efb84c0f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/arenas/table_arena.xml
@@ -0,0 +1,52 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/base.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/base.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b52d438e453ba0a2d808052c8b85f5d0026860bb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/base.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/PEDESTAL.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/PEDESTAL.dae
new file mode 100644
index 0000000000000000000000000000000000000000..1a2f05b303d45623abbcc448f20918d093dd412a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/PEDESTAL.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b0a44b95452421196cccdc1347fbb3f6da6df7a32c5929ba92ae7441b5b1d60
+size 6230755
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/PEDESTAL.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/PEDESTAL.stl
new file mode 100644
index 0000000000000000000000000000000000000000..c40e88cd1aa5e4da35fae1d5225c9eb9750d4dda
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/PEDESTAL.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:405962a9049d58faddfbde642e9ac3fafeead06e1799535eeef4cf01ebd6b25f
+size 3735734
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/pedestal_link_collision.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/pedestal_link_collision.dae
new file mode 100644
index 0000000000000000000000000000000000000000..56bc76595dd29d948d9bbbdfee65d6f2b585921f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/pedestal_link_collision.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b7361cb274df9005294234680b7006aebd4c0067f1515b7a936028cba4c65df
+size 12281
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/pedestal_link_collision.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/pedestal_link_collision.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ec574cfed5b11239176ef327c720f4a8527e6f99
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/base/pedestal_link_collision.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73296b38d0d267f5d2aa0f8626432807e7e7fc3b6aa50263da9d31620d5cba1d
+size 10284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..6f5b709e908c27cdf3c8765f56413eb6af699807
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95f8bc9e80eea217afb4e7b271fa71ce52a7054bbe4b20868c1b889e6bc66f65
+size 334255
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..de85f896b6886b97172f1a323887787c4fde93da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fc0a1a1bee8949c2247ccf837883af5e31516bc54dede2d962243691c0c8c68
+size 260834
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..e52f83ddac62ad4769b1f4a0d905821a8922892d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcb09dfdd019b803a0761c367f8acf6f331a0c4984952cc424ecb5a174087422
+size 233396
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..949a5e84ba7069b847bd31e392bed2e907baf633
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/head/H1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96cbc1f7b8c5a7a89927d0a91cf7d3993df6ce43f31e96c6c7baeae0e34d072a
+size 174384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_elbow/E1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_elbow/E1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..da6d38ed1daa5acecd26711a2d7ce8732542dc82
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_elbow/E1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e1b3868d8c43baf307e1cca7e700c8be75c6cd983dace1d76ee2c3b563243e2
+size 796534
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_elbow/E1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_elbow/E1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bf0788d634e49a84dd3a96e911ad7f8b6412d0c9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_elbow/E1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d82652d80f7e449212fb1528b960e0f6bcae85c3997d4613a3a0e1cda5673db
+size 600284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_forearm/W1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_forearm/W1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..a9e9f195788862b322f3475f5d2bebe6a4fa755a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_forearm/W1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9df525029abbdab8c7ce566d57b6067bcfb804da5a27a6ce43dc59cd2e63ffd4
+size 910723
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_forearm/W1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_forearm/W1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..a83a1f983cd80851fcdb8de3bf16a93a937925db
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_forearm/W1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebc784f78f7353d7b0287682c1bd24c62bf6d3c87e8ee84ddfb350c45dcc9dc8
+size 687784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_shoulder/S1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_shoulder/S1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..03291076e90badf09c093d129412049a17092dab
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_shoulder/S1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50231b610a1f01a5f22394942e324ea3e4b62c5762a5201b16bab2a2c0ac5bdb
+size 434567
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_shoulder/S1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_shoulder/S1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5b96e1bdf50785915965ddeeb8e0224aac1602cd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/lower_shoulder/S1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acc43d8ad81f46d080ea3915e8a45a5ee719e9c94d66622c1a5fb7adc35facc6
+size 328784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link.dae
new file mode 100644
index 0000000000000000000000000000000000000000..259133a4678951d9a2342ca517b12df60b5139ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5067ef42e5575227ff0c2235215c9e1bea804dd62dc87c024a714e03a9d7220e
+size 5037837
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..05cb991be1b7d7f607c0578816f57bed7ba25907
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:553cb6a7e4a445d08b3b24f27ebc727cfc0d7602537c13fa7d27fa226dec790b
+size 3463084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link_collision.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link_collision.dae
new file mode 100644
index 0000000000000000000000000000000000000000..4e181e71e730a4f5dc6c520e6a4b0e0026414986
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link_collision.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0303ca3d47d90cc70969e540a234356a551c5e9d62cbccec35d3bd39c11ad50
+size 703067
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link_collision.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link_collision.stl
new file mode 100644
index 0000000000000000000000000000000000000000..369ffe19ab92bce8d47f31ffaddfff0fffd2d485
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/torso/base_link_collision.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2980b03a0e699f0d85dd4ce3074758ee7067de4b81409b5fa1db4e3078c8b58e
+size 458034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_elbow/E0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_elbow/E0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..71505fcbc21080075810f9aadcf1f4c3555d4c38
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_elbow/E0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f26641362fba070757907374cf59cf3052dc4d9b133090552c2f83b6cf46705
+size 877438
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_elbow/E0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_elbow/E0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..24da66fc3989fd717136338504e79e27df3ad009
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_elbow/E0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6685cc363e0ad06afeb49ea4a462f76ffdf65f39308e2f86ff179855a77d34a4
+size 656284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_forearm/W0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_forearm/W0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..28228dc7609b2eacdd5aa229e1629966f645e839
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_forearm/W0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f32b0aa2010f8b021bd649166207eb86250d19cee8c1ea25dbf8cf8ee41eecf7
+size 1814264
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_forearm/W0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_forearm/W0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..1d5f2315eb7d537d2fe2b632e9f2527774ee914c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_forearm/W0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:873f7caca17829a2f11f110cbc04aa051c7b221f5b96b04d2db723e7c8adf9f5
+size 1316034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_shoulder/S0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_shoulder/S0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..e21cfb985f05fb7066386e57e012795c4c2cb9d4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_shoulder/S0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97ce864d405db2de44ddd73689fe8e60eec16cc307daf67ee2d6bb886eb9b161
+size 2640134
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_shoulder/S0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_shoulder/S0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..36489e53cc80378e0cd7f350e6bface6f99f9dfa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/upper_shoulder/S0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b66b83af3303764224e2a8ec1b6257082a572c6eb04ddd4ab4a70dbfea1e0017
+size 1917534
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/wrist/W2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/wrist/W2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..50a054099069dc3e5766f4155f6ad5ca350947b9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/wrist/W2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a6cc33e009acc6f3f408008839f5a7e208cfb6caa87e8234425279cc726b109
+size 925929
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/wrist/W2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/wrist/W2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..62ae582181ced7e2289fa084bc16ceffad994b52
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/meshes/wrist/W2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17391c5fbee065d6e39feb2b3f77b9e8a0b12a0cd471c4cd4277844a0492ab6d
+size 566584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/urdf/baxter_arm.urdf b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/urdf/baxter_arm.urdf
new file mode 100644
index 0000000000000000000000000000000000000000..b0e2aa6977a6122dba14f55a596277c554c914fa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/baxter_description/urdf/baxter_arm.urdf
@@ -0,0 +1,1546 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+
+
+ transmission_interface/SimpleTransmission
+
+ EffortJointInterface
+
+
+ EffortJointInterface
+ 1
+
+
+
+
+
+
+ /robot
+
+
+
+
+
+ true
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+
+ 30.0
+
+ 0.0 0.21 0.0 0.0 -0.8 -1.570796327
+ 1.3962634
+
+ 800
+ 800
+ R8G8B8
+
+
+ 0.02
+ 300
+
+
+ gaussian
+ 0.0
+ 0.007
+
+
+
+ true
+ 0.0
+ head_camera
+ image
+ camera_info
+ head_camera
+ 0.07
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ /cameras
+
+
+
+
+
+ 30.0
+
+ 0.0 0.0 0.0 0.0 -1.570796327 1.570796327
+ 1.3962634
+
+ 800
+ 800
+ R8G8B8
+
+
+ 0.02
+ 300
+
+
+ gaussian
+ 0.0
+ 0.007
+
+
+
+ true
+ 0.0
+ right_hand_camera
+ image
+ camera_info
+ right_hand_camera
+ 0.07
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ /cameras
+
+
+
+
+
+ 30.0
+
+ 0.0 0.0 0.0 0.0 -1.570796327 1.570796327
+ 1.3962634
+
+ 800
+ 800
+ R8G8B8
+
+
+ 0.02
+ 300
+
+
+ gaussian
+ 0.0
+ 0.007
+
+
+
+ true
+ 0.0
+ left_hand_camera
+ image
+ camera_info
+ left_hand_camera
+ 0.07
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+ /cameras
+
+
+
+
+
+
+ 600
+ 1024
+ /robot/xdisplay
+
+
+
+
+
+ 0.0 0.0 0.0 0.0 0.0 0.0
+
+
+
+ 12
+ 1.0
+ -3.14
+ 3.14
+
+
+ 2
+ 1.0
+ -0.001
+ 0
+
+
+
+ 0.05
+ 50.0
+
+
+
+ 0.00
+ true
+ 100.0
+ /robot/sonar/head_sonar/state
+ sonar_ring
+
+ true
+ 100.0
+
+
+
+
+ 0.0 0.0 0.0 0.0 0.0 0.0
+
+
+
+ 1
+ 1.0
+ -0.5
+ 0.5
+
+
+
+ 0.004
+ 0.4
+
+
+
+ 0.005
+ true
+ 100
+ /sim/laserscan/right_hand_range/state
+ right_hand_range
+
+ true
+ 100.0
+
+
+
+
+ 0.0 0.0 0.0 0.0 0.0 0.0
+
+
+
+ 1
+ 1.0
+ -0.5
+ 0.5
+
+
+
+ 0.004
+ 0.4
+
+
+
+ 0.005
+ true
+ 100
+ /sim/laserscan/left_hand_range/state
+ left_hand_range
+
+ true
+ 100.0
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/CMakeLists.txt b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7282c39947f348162715e7e829a4aa5eb0f6e2b9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 2.8.3)
+project(franka_description)
+
+find_package(catkin REQUIRED)
+catkin_package(CATKIN_DEPENDS xacro)
+
+install(DIRECTORY meshes
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
+install(DIRECTORY robots
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/mainpage.dox b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/mainpage.dox
new file mode 100644
index 0000000000000000000000000000000000000000..941d0bf97714038b58c39c7dcd816e901e22c023
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/mainpage.dox
@@ -0,0 +1,6 @@
+/**
+ * @mainpage
+ * @htmlinclude "manifest.html"
+ *
+ * Overview page for Franka Emika research robots: https://frankaemika.github.io
+ */
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ef5e672efbb990561b36fcee2c15b2f61cf42065
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d07a740392f3b9b0816f65d64fff9927d3d57c897870fc4b6ff9c56fff3a0c8
+size 1684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/hand.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/hand.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bb315217a60e27343b84a9d4e3a4686762c4fc8d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/hand.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94493e94f30fe940f2c8ca2f155c3bbe67bbff406d3edf5e261670d2f0f6e2ed
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bbe58384ff30b933eb8758429c4f5cbd970c1b50
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfc6d94330de8ddb005b311bfdba9f3b8e1aa7c256b71592ee7ff32cb9a9a5aa
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b7e855112619448e2cf63660e07387871de542ac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e41a39a94108fcf56aacff603fc91ec80541f4c1af17b51a0de5617f5566e6d2
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6ba548f4137d4ba09e7b0d9299fa631b27af1ea1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:370f7605a0fae3529db169ded50f52f171024aa792d4d773bc84197301f6a039
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..7115ba0e92d33fd3a2e6e2087df980ae8b9a6730
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a8d638b9349c6c0eefc4e888636ac4838c4b27170f18a51699321118af709c1
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link4.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link4.stl
new file mode 100644
index 0000000000000000000000000000000000000000..88c6db70bf3c3b68bce08b9bcb5142050b1f9079
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link4.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0180ebb5772ec9840cb049750cffb29a9ddc90311752a16ea34757782ef9e48d
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link5.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link5.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5eaf5c8ec2155135ab9297d51e3dae6e5e280675
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link5.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd17e688c7870e722283525879643d53a74c0024d328b0e14b034b54c8b6c31a
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link6.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link6.stl
new file mode 100644
index 0000000000000000000000000000000000000000..828ad3bd384b22ef734d8add0e50d6ae449dce9c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link6.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20b768e99a0e0440b5754dcca108016434e57937cc356acd9c352ccd3cb27f77
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link7.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link7.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2047756ec662f051af90fe61266998bf16e655fe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/collision/link7.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92ac6afcf7574c034d3170d8a68e95ac9048ab9d0dd5bbd8311b86e551b9ab1c
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/finger.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/finger.dae
new file mode 100644
index 0000000000000000000000000000000000000000..06f2732f3d24135ff2c06b0217fc328925cb40b5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/finger.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9d5471e9cdf76f85493df6fe6c76d29810edd21577c204009160496d5b77fbb
+size 51123
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/hand.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/hand.dae
new file mode 100644
index 0000000000000000000000000000000000000000..073dd433d2574db0060d31b8f03f26a4f5851bee
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/hand.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9602bbad114cfefd4bbb0cd0a57f3021bce9efbaaa1604a5e72bfb76926bb019
+size 548949
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2ea2fbee592e033a1cdf431400ad8c2ac4091248
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9ceaa66bb3a734e3a32f2f737ae57a29e922f4a962ed77b9bb8d8e25cd33159
+size 1590896
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2030d60bcfb6e3f00c6287fbe8cad91be27118f0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c8b7b7c1217d620a811fc0ee52d1d1b0e1470de955e7453872aac3f15cf7c5e
+size 978415
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..64981bd82b803b79f46005dd56885b4ae01e9d87
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c44d0364f0030007e427106a4e842d835ca43902716cc46ee4f3342dab189e12
+size 998486
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link3.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link3.dae
new file mode 100644
index 0000000000000000000000000000000000000000..23d6124df5e3d5696241441fbdf7dcfe53dbe150
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link3.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dab39a126153fb82f3650cca6de63a8e978f851aa5020a8e91b3d9d548dbba3d
+size 1099651
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link4.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link4.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0ce1680db10d42992cb781fa23e3b5db43dea3ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link4.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e03d680e3a4a4555d673bcb8cb466e479f5cb069a5fc8a0b0f99c089c50fd63
+size 1145491
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link5.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link5.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b6911ff709357cc25b27355fc36873d1d30f9cc1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link5.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0be76681192578a14d6ace89527e4ee418f7395e825e285125e05fd998d24e3e
+size 1438169
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link6.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link6.dae
new file mode 100644
index 0000000000000000000000000000000000000000..adac012b16351aecef432a28bd593edc0872a9ae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link6.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed9c57432b079d55b9954775f2ddfe34e8b904f683949b8eb6314238f8afa46e
+size 1727767
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link7.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link7.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b6d289bc5b7d51793fd2bf805695356eed8ae3ac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/meshes/visual/link7.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71be614f734bd27b2d7dec3e8bb022251cbbfce38b0a12dbfc1b88bc0513822a
+size 935952
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/package.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/package.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6db900392f56e7b1529f93dd5304874fbe38a4f9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/package.xml
@@ -0,0 +1,17 @@
+
+
+ franka_description
+ 0.7.0
+ franka_description contains URDF files and meshes of Franka Emika robots
+ Franka Emika GmbH
+ Apache 2.0
+
+ http://wiki.ros.org/franka_description
+ https://github.com/frankaemika/franka_ros
+ https://github.com/frankaemika/franka_ros/issues
+ Franka Emika GmbH
+
+ catkin
+
+ xacro
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/rosdoc.yaml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/rosdoc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96ee597ef5cacd0f223f676c738f396f0810b78c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/rosdoc.yaml
@@ -0,0 +1,2 @@
+- builder: doxygen
+ javadoc_autobrief: YES
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.urdf b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.urdf
new file mode 100644
index 0000000000000000000000000000000000000000..59b9548ac779bfb0369344c50d2b5e2f1900c1c6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.urdf
@@ -0,0 +1,76 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.urdf.xacro b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.urdf.xacro
new file mode 100644
index 0000000000000000000000000000000000000000..643fc608f05f2bd79f0212e01e1a01086d02bd57
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.urdf.xacro
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.xacro b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.xacro
new file mode 100644
index 0000000000000000000000000000000000000000..3f3a209faff8476ee4521a812c5d28314fc8de51
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/hand.xacro
@@ -0,0 +1,80 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.urdf b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.urdf
new file mode 100644
index 0000000000000000000000000000000000000000..2e2b6692b225ba21d0f01af729cd1a0088514b13
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.urdf
@@ -0,0 +1,213 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.urdf.xacro b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.urdf.xacro
new file mode 100644
index 0000000000000000000000000000000000000000..ffd0bf1352da0f059e827233f399bade484d46d4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.urdf.xacro
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.xacro b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.xacro
new file mode 100644
index 0000000000000000000000000000000000000000..452e56804b247f347d38c844b2b4783462c6cf0d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm.xacro
@@ -0,0 +1,217 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm_hand.urdf b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm_hand.urdf
new file mode 100644
index 0000000000000000000000000000000000000000..824e988509d7216dfcb2e2d9fe006fb90e456105
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm_hand.urdf
@@ -0,0 +1,286 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm_hand.urdf.xacro b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm_hand.urdf.xacro
new file mode 100644
index 0000000000000000000000000000000000000000..c2415c2c6ea7827364aeb7c4f2e5e1918c22d51e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/panda_description/urdf/panda_arm_hand.urdf.xacro
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/CMakeLists.txt b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a0a3a886382af73c4d3c5a8c53ec87bb21ed9fe2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 2.8.3)
+
+project(sawyer_description)
+
+find_package(catkin REQUIRED)
+
+catkin_package()
+
+foreach(dir config meshes params urdf)
+ install(DIRECTORY ${dir}/
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}/${dir})
+endforeach(dir)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/config/sawyer.rviz b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/config/sawyer.rviz
new file mode 100644
index 0000000000000000000000000000000000000000..edb9b7dfcd14be2ba6a4efd0f8c80141f8ef92df
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/config/sawyer.rviz
@@ -0,0 +1,216 @@
+Panels:
+ - Class: rviz/Displays
+ Help Height: 78
+ Name: Displays
+ Property Tree Widget:
+ Expanded:
+ - /Global Options1
+ - /Status1
+ Splitter Ratio: 0.5
+ Tree Height: 728
+ - Class: rviz/Selection
+ Name: Selection
+ - Class: rviz/Tool Properties
+ Expanded:
+ - /2D Pose Estimate1
+ - /2D Nav Goal1
+ - /Publish Point1
+ Name: Tool Properties
+ Splitter Ratio: 0.588679016
+ - Class: rviz/Views
+ Expanded:
+ - /Current View1
+ Name: Views
+ Splitter Ratio: 0.5
+ - Class: rviz/Time
+ Experimental: false
+ Name: Time
+ SyncMode: 0
+ SyncSource: ""
+Visualization Manager:
+ Class: ""
+ Displays:
+ - Alpha: 0.5
+ Cell Size: 1
+ Class: rviz/Grid
+ Color: 160; 160; 164
+ Enabled: true
+ Line Style:
+ Line Width: 0.0299999993
+ Value: Lines
+ Name: Grid
+ Normal Cell Count: 0
+ Offset:
+ X: 0
+ Y: 0
+ Z: 0
+ Plane: XY
+ Plane Cell Count: 10
+ Reference Frame:
+ Value: true
+ - Alpha: 1
+ Class: rviz/RobotModel
+ Collision Enabled: false
+ Enabled: true
+ Links:
+ All Links Enabled: true
+ Expand Joint Details: false
+ Expand Link Details: false
+ Expand Tree: false
+ Link Tree Style: Links in Alphabetic Order
+ base:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ head:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ head_camera:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ pedestal:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_arm_base_link:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_arm_itb:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ right_hand:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_hand_camera:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ right_l0:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_l1:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_l2:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_l3:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_l4:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_l5:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_l6:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ right_torso_itb:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ right_wrist:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ screen:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ torso:
+ Alpha: 1
+ Show Axes: false
+ Show Trail: false
+ Value: true
+ Name: RobotModel
+ Robot Description: robot_description
+ TF Prefix: ""
+ Update Interval: 0
+ Value: true
+ Visual Enabled: true
+ Enabled: true
+ Global Options:
+ Background Color: 48; 48; 48
+ Fixed Frame: base
+ Frame Rate: 30
+ Name: root
+ Tools:
+ - Class: rviz/Interact
+ Hide Inactive Objects: true
+ - Class: rviz/MoveCamera
+ - Class: rviz/Select
+ - Class: rviz/FocusCamera
+ - Class: rviz/Measure
+ - Class: rviz/SetInitialPose
+ Topic: /initialpose
+ - Class: rviz/SetGoal
+ Topic: /move_base_simple/goal
+ - Class: rviz/PublishPoint
+ Single click: true
+ Topic: /clicked_point
+ Value: true
+ Views:
+ Current:
+ Class: rviz/Orbit
+ Distance: 2.27867007
+ Enable Stereo Rendering:
+ Stereo Eye Separation: 0.0599999987
+ Stereo Focal Distance: 1
+ Swap Stereo Eyes: false
+ Value: false
+ Focal Point:
+ X: 0
+ Y: 0
+ Z: 0
+ Focal Shape Fixed Size: true
+ Focal Shape Size: 0.0500000007
+ Name: Current View
+ Near Clip Distance: 0.00999999978
+ Pitch: 0.240398526
+ Target Frame:
+ Value: Orbit (rviz)
+ Yaw: 5.87858343
+ Saved: ~
+Window Geometry:
+ Displays:
+ collapsed: false
+ Height: 1016
+ Hide Left Dock: false
+ Hide Right Dock: false
+ QMainWindow State: 000000ff00000000fd00000004000000000000015600000360fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003a00000360000000c600fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f00000360fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003a000003600000009e00fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000006b50000003efc0100000002fb0000000800540069006d00650100000000000006b50000024400fffffffb0000000800540069006d00650100000000000004500000000000000000000005590000036000000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
+ Selection:
+ collapsed: false
+ Time:
+ collapsed: false
+ Tool Properties:
+ collapsed: false
+ Views:
+ collapsed: false
+ Width: 1717
+ X: 203
+ Y: 35
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/launch/test_sawyer_description.launch.test b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/launch/test_sawyer_description.launch.test
new file mode 100644
index 0000000000000000000000000000000000000000..37e9393501e5538492a09d45448ef14d7e96ce9d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/launch/test_sawyer_description.launch.test
@@ -0,0 +1,19 @@
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/base.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/base.dae
new file mode 100644
index 0000000000000000000000000000000000000000..79993c4d499be9416a86e6f5abdcb35845fb7ef4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/base.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f879eb086ce265c73fa16b6c39b9344be2545e293fabbc4cd37ae405f991101
+size 1000898
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..83382c3684a5416298a7889732f86b57a5635efd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97b3997a278b7d7be42142f49f435a4a0d7856736b943bfb3590dc43210055f6
+size 264934
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..8f39b917258adab2e893e8ea9fef2abb8e0f1bb6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d5ef27f6f53e5ab78ee0791d28f47099753ab5966132f67339e3514346eefd8
+size 3760739
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..263bfce5dbb757036183d9c849615e4508891ba9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f4a7f56bcf4cbfbb72414acd61956ccb8db88ca5ec4074ff626a44cf41a18c1
+size 675584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..253e63fd21bb6796d2357984078aca396175869b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f500967d2a91c6321a453dc3a5949013729354d04c8de4a8c20116e61918408e
+size 473148
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..964ea6cdf273aa59cb968bec8b71889df8bc5727
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6d8ba089c3da8a4a40e176a13928e3a39cfffa2fdc83311ac2f6b59035ab6d0
+size 511884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b83494d7fb80655f2e0d5292b1eea479e1b74ca2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a56dc6354f42d850b35d246266de67c0fb22eb17576e19712340f5909116aed
+size 654737
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3ef2de15c32d518d145e0d76a78894fd17d563dd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a0d2ccf5668f737409d0e7fc2578b9e24660298d09787852494a7adae8c58b1
+size 133734
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l3.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l3.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2a6630f83b9739810d4bc5d6506ceefaa28a5587
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l3.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84f67d46fcc4acafd8fd7bd620ef6354b884d3f96eb4d09e792b40ce9ead0669
+size 618017
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..f0d4d108e88ba1629d943451db1d68916e810bf6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9faa1f689135fcda53c50b18f9a5c55bc718bb30790395a20821e96b87a174d0
+size 160034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l4.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l4.dae
new file mode 100644
index 0000000000000000000000000000000000000000..08e9bafdf91240f1c1213757797b5d33b63e0a91
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l4.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b1742337c74de3ec0841eeb651291a09dadd6cbe84e2b1f26a1c4a99272a1c0
+size 4922491
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l4.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l4.stl
new file mode 100644
index 0000000000000000000000000000000000000000..a7d307ae67fac5efc496faebd8050ebbf73dacaa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l4.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9f7032958ae8feef741a4a073f59f3cf9f8f491d93505d0bc9047269cb2e1f2
+size 208284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l5.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l5.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2e68e1492190a8672d7cee28cc916c5fa9a60cc2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l5.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1ab84f9f671c9842a5673c4236af79c9809ea19469d447bd156fdd8283027d8
+size 2121615
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l5.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l5.stl
new file mode 100644
index 0000000000000000000000000000000000000000..1ca0d1a4164f61bde4838fef7f86312150c5bc45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l5.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5078849cf3e1ae7790a6014e1ecb51eda84b44d2d91206783bb0b1fd1740e9ff
+size 176534
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l6.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l6.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0a170130340091219653ef8fc94307b7798669ac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l6.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:096155b75cd14bdb9e921c990f358844c9212178f32034a3fa92ef299039b7dc
+size 1990228
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l6.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l6.stl
new file mode 100644
index 0000000000000000000000000000000000000000..9c369213a5624835906d3f9abc7fa32face88e95
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/meshes/l6.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6f4d3eb3e3b89bd364b7d92cd5fa229582295bb7ed6c37f494f370897151acd
+size 261534
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/package.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/package.xml
new file mode 100644
index 0000000000000000000000000000000000000000..22d3d336d400d4091af0ac8e90a1261903b39153
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/package.xml
@@ -0,0 +1,29 @@
+
+
+ sawyer_description
+ 5.0.4
+
+ Description of Sawyer Robot from Rethink Robotics.
+ This package contains the URDF and meshes describing Sawyer.
+
+
+
+ Rethink Robotics Inc.
+
+ BSD
+ http://sdk.rethinkrobotics.com/intera/
+
+ https://github.com/RethinkRobotics/sawyer_robot
+
+
+ https://github.com/RethinkRobotics/sawyer_robot/issues
+
+ Rethink Robotics Inc.
+ catkin
+
+ robot_state_publisher
+ joint_state_publisher
+ tf2_ros
+ rviz
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/params/named_poses.yaml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/params/named_poses.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d8d30349c8a2913c247ba921be54111f516bf6c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/params/named_poses.yaml
@@ -0,0 +1,12 @@
+# ------------------------------ Sawyer ------------------------------
+named_poses:
+ right:
+ joint_names: ['right_j0', 'right_j1', 'right_j2', 'right_j3', 'right_j4', 'right_j5', 'right_j6']
+ poses:
+ neutral: [0.00, -1.18, 0.00, 2.18, 0.00, 0.57, 3.3161]
+ shipping: [0.00, -1.57, 0.00, 2.79, 0.00, -2.79, 3.3161]
+ head:
+ joint_names: ['head_pan']
+ poses:
+ neutral: [0.00]
+ shipping: [-3.14]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/urdf/sawyer_arm.urdf b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/urdf/sawyer_arm.urdf
new file mode 100644
index 0000000000000000000000000000000000000000..7d483774a20c7462672ff8eaa70b338e52a834e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/bullet_data/sawyer_description/urdf/sawyer_arm.urdf
@@ -0,0 +1,234 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/handover/panda_panda/demo.hdf5 b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/handover/panda_panda/demo.hdf5
new file mode 100644
index 0000000000000000000000000000000000000000..643380a388aeb3c084b1db807142e51baa627620
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/handover/panda_panda/demo.hdf5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2e4b04b8a46dea44218a733d022c14a8ceda723b338d274795d7274f0eeed24
+size 5045092
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/lift/demo.hdf5 b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/lift/demo.hdf5
new file mode 100644
index 0000000000000000000000000000000000000000..d1777467a5c63a7bd924ef0d399e06e7f268eb2f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/lift/demo.hdf5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08c060c1dcd30582a4f84cf1b9c2daa009c9c75262e78235eca1bb58d0f84e61
+size 559784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/wipe/panda/demo.hdf5 b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/wipe/panda/demo.hdf5
new file mode 100644
index 0000000000000000000000000000000000000000..c8ade1a68fba3d17e50597eaf721c118fce8c11a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/wipe/panda/demo.hdf5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06c164a9fc609d6374b5ecc30e14e3f331e7f18e2ea6d7cb5888b14f16630a5f
+size 48112
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/wipe/panda/models/model_1.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/wipe/panda/models/model_1.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a5b8948c45f648069497f96e28cc3699d99f24d2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/demonstrations/wipe/panda/models/model_1.xml
@@ -0,0 +1,247 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/jaco_three_finger_gripper.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/jaco_three_finger_gripper.xml
new file mode 100644
index 0000000000000000000000000000000000000000..da48f24ae15f9f9f4716a8942a79619dd58307db
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/jaco_three_finger_gripper.xml
@@ -0,0 +1,127 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0d8e1bcf055e86f815d5e6bf598ec4eca0e61f4d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73f6662e0e9c47e43b9f3018069cf6285f40e163d3555a6928607f058ed7cffb
+size 141699
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..cfab8c125d4e18b2404e2af4cceea7a1d92e1782
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75491d30baa89763ac7b5cac43ea69aa42a863527805b6166709371c3a7dbfa7
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bb9ee71efcac9cec598405d44a94c6dc5e3f72cf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:898c8ca4f5790fe956bde643836338346ae740b7e7f27a33ea0ea81fe2c32027
+size 170213
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.stl
new file mode 100644
index 0000000000000000000000000000000000000000..e511f032175daecd562f67950a904f67ebde33d9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_distal.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27ce521eedbd5e70b6bfa364a29b357d2b03f15ebf1ffeb08c46bf28ccf7f6ca
+size 97184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..7a0c2598d4214ef6e10bf4c4c775b5dbc09488e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce16acaa3571b4023c131c7310744fc28894c9703c9833a3d274bb38a704fa33
+size 200973
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..a33160bac0a49d454b675786838b316a2cbf7c2b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48583af9d2f3a7234a334faae1c997bc77b0ab74b631d6bbe972be92f94a1597
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..80fd9f750bfe1d3bf14d206a0833ae58ef306c08
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af3119fe7f58591b81ce9fb70c98228fe7ed4b5d33111b36e1a3724f220fa49b
+size 235345
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.stl
new file mode 100644
index 0000000000000000000000000000000000000000..363b6403b838407c0e9ecef365f34dbcb40e6ef6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/finger_proximal.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e4c3b68fe806f53fbae43503176981be61aca54a396d1aa9f3e10f87bc4e4f5
+size 135584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b904643656d3c019560c12d8458ec9de710672fc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:018d65704e40f198c2af21ce69de59b955f13c238875f0df229ac7f8436bb003
+size 2286771
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..a3f5d677bb381e83691b67687eb2a4ea527d06d9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04e4454818bf7b38eaa107b96761eac370f7e7eb506adfa59bf246b02f4c1ef6
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4fb49331561c303f59fd9e70fe456c98b9c83a88
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b12a144371e024b3b16acdb473a2a68eaf939dbed963895be80cb0062af29bc9
+size 2834748
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ffbab3ecdcc1151d9003a80eb6fb53aa54161038
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/hand_3finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e1a0ecb85ee2c8213ec197e75e53c30dbfcca06ca114bef1a1a80c798349a22
+size 1395484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.dae
new file mode 100644
index 0000000000000000000000000000000000000000..84a5951ee6c71476b32936883981c593580974d0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf4d3b83ed06116b66c6d33a7e2d5e2d2643b7a822147c44bfbba80b5b27a234
+size 35177
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..7fcc1144e135f419a54d66bf76a29b2b659ecbd0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a727be7aca01ef35192e60aa4db05200a956579a4f7b26d9b232b5def096ec2f
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c5693cf2b2b7f57d3beb88048469bcec72be5bfa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2694fd3dbf1c3341fbc447b298457d93a80a6697aeb3bc6849f7238608dcba93
+size 35173
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6ae5c8c3335e4d54343b30fe8fc9959d20fb1419
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/jaco_three_finger_gripper/ring_small.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d60973e9d9ff8c6b05d49a120bab2fc8df42e7270d72e369808e9c1678db5eb
+size 22684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ef5e672efbb990561b36fcee2c15b2f61cf42065
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d07a740392f3b9b0816f65d64fff9927d3d57c897870fc4b6ff9c56fff3a0c8
+size 1684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_longer.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_longer.stl
new file mode 100644
index 0000000000000000000000000000000000000000..9082187ed0ee299915e8fa4cc26867d2ea125df6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_longer.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8857e33ee59c3a395c3036a11e4f1731c88c203c821bb03ca97737a2d04cba27
+size 2884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..fb26096106f740cb0bddbb967a50ee0c5c8a3ab5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:caefea762f2d18ca9412cf6c2e64e5007ad03571181d790a8e0c828b1b2035cf
+size 51239
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..91e73caca094b6a58d95209ccf1b6f045a36a8be
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54213c34eb8bc8db0d52a3d38c28954dcb7e3d36395d5f78a77e6f5efa5f1d69
+size 432
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c0490203d61752c1bee87f584ff289734d8b7a95
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:105cce7bacc069fd288a55b214fdfe37287d52f26d844940b56f022ad5d4839e
+size 65235
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5b8512a4d5d2de4c7019ca5fc664dfbf309c6f61
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/finger_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e49613148a74f1ba9b5793813078e5becf15833e0296073dc1c523508be35ae4
+size 31284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bb315217a60e27343b84a9d4e3a4686762c4fc8d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94493e94f30fe940f2c8ca2f155c3bbe67bbff406d3edf5e261670d2f0f6e2ed
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..4d3078316aadb6d8ef825e214d4d0214855c9e88
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe5d445d509e44a9bd107ff78f0b03c49752c98e5e0a8ebafddf2b6cf5a8b380
+size 549239
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..94839a4e83a0b567a1833d23c2695282b5c14d6f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:732ba90affaf8debb3a4f3bddbe2bc5438137ed5cf325ce1f5829d0bab5e7b82
+size 1038
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8ce0a4209c5ce0620b479f3f7d72e8f459010600
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8fa0939d6143d690838f611716901bbf5220d12068a5a60d76251bd3d24862f
+size 737011
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..440763e579e0bc61fb9986de9a751e09f5269e40
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/panda_gripper/hand_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdeb1a924b8d0f1f997f47d6d8af102c7290a1d5144145efdc39501dc13f0fa1
+size 353984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.dae
new file mode 100644
index 0000000000000000000000000000000000000000..31a7fbc7e1ebbea72c8fc0ad7c5db17cff36b1a6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ea82518c235326eec78addddb4d21cfb629c0b28299b7382a7f361df4f7d98f
+size 2512513
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b94e8f669d65d43bbcff3693856082fa67812c62
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa1d09f88548c514fdb275dd03d09bc0d961875aeb5e016f57740825b1d93c8f
+size 233
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.obj
new file mode 100644
index 0000000000000000000000000000000000000000..cee3c58a18040882fb7831efd6d239b767e763ea
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38b7adecc394bc2d829dc7de308126a93d53a2de0129318d2cb01fcac3ffa10d
+size 3149293
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ef9368c24b01dc3fa8f1272254133ab5fb4fc119
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/connector_plate.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a23b14473b4a6c38a7702982b77b3f12a15859e4eb96eb0b4b343e41d08f679f
+size 1758684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.dae
new file mode 100644
index 0000000000000000000000000000000000000000..4cc25cd14b357177bc85c3504711caa2eaa77d3e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d58734eef052462c3832d56bb1ea82621bfd23b3b78ec76b45e4ebcf3453b416
+size 984153
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..43a6db963b31a557c06a46fa706335907de7e316
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50421e5c4910c71f56d578f5f46d5c025a3c607c505644cf72ec21781a05ef55
+size 419
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4a48227847c776763fe9af753e982b57dd0c2c39
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad10083f865da3aabb7d578f62c8f41414a6bca7e6ed198a9cddfff60da132a5
+size 1213713
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..af9970c234745044aeaaaa7baaf1ffbf0cc1c933
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/electric_gripper_base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3112ecc88a650b57a3f20dfcafb28d41cc88639b6e8ac85f3cb8739605a2da39
+size 618984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.dae
new file mode 100644
index 0000000000000000000000000000000000000000..ac663e964f00bbf1ebc9ee78859625ca563230d5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b6517e2e44b503984beee9fb83c7a1e5228f532c943757a00959d7748bd44a7
+size 235303
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c5ad9787119bae0d504525e5d7787939231632fa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eae3b81e9059410c438231b86e57f2b78692ff46b4cd665c9ca90cab0f27d88
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c4edffa28e35811c0a803ccec6ecc067580bd6cd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4c772ec1757804654ee972496cc6df4a501fbf4f106aabc916cb43b29c11056
+size 287325
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.stl
new file mode 100644
index 0000000000000000000000000000000000000000..91bed464dbc7d62204f25fbc012d8e9b1192c4d2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/half_round_tip.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a38e5377a4806d6d997efa2afae78d37aee10353e7c59dd991b2f41ade85e6b
+size 148234
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.dae
new file mode 100644
index 0000000000000000000000000000000000000000..58f22e8b17372dd508867638c61b45b2257ed979
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9dd36ca75cd19bf67c8864c2c1a13c802be96c6f9f370f51764cf56c9f732d4b
+size 478054
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..ea8491ab829d563156e2f50e578a016a5f827c34
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9bf5fa3bfbe71eda7979c64f1280ee6b03ab861e27d863d4bc8ba53e0743fce
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4150d9d4fd2a595b7208ec263140321ecaaccddb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8eab9e9f9565e7e61385ccff06fb11068b20330c96c4cc759347aa3d2388c713
+size 615399
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ff20d7b9d30c61b52d7a04b86598370e26d78c98
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper/standard_narrow.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdb611ec499f718b82d894ae51d0babe6500c3fdbadd125143ead59bb7795a55
+size 298434
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..c15965c0cfd22fff5f937cc8a0d09f3d4afd2e33
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:043901d9c22b38d09b30d11326108ab3ef1445b4bd655ef313f94199f25f57ed
+size 7284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..893ecc12ce943c24747573eac5d14a52e2d03457
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b83c49589f1a1ed2981aa0e6adcba03216536f633b2037d9006bf314b9e57e4
+size 51616
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..dc22f558729864de277fd1acf24bee78923e9507
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_finger_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be21c5e18c901d4747cbc8fa1a2af15dad7173ff701482a8517e33278432bb21
+size 33984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle.stl
new file mode 100644
index 0000000000000000000000000000000000000000..1375e4e8f379978449da10ccbb14dd6555074b67
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d341d986aa8aca7262565c039c04fb6b4c0f2f5c93443519eb7ca9bfc67ba17c
+size 5484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8be2f8503d756381333de0f139828644e763f1da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdca583239b4da0024c1e8d7dec8f2cbd191b228375028832d41f6bab962c0a2
+size 68467
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..33365cf8627bf74af9ffd6d09241ec770f5cc2ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_inner_knuckle_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d61e9c5304d8015333856e4a26d99e32b695b6ada993253986b1afe8e396ab11
+size 43484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..4f19c840371f5e5ece42bde7876de650e9ac15cc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd7da1b31b73f1aa1d61b26a582cea10f16f17396d54b8937890fa51547d26b0
+size 11684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d77ecc2e2f7c68e2e81899647724c164cff86174
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7fe9d55a54b4604e6e4781eeccc5b69bff431f37abad051024e5cef75302768
+size 121033
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ff31816334d2214b68d5afbee8ed6ca2863d41a4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_finger_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:666a92ee075f6f320ffb13b39995a5b374657cee90ded4c52c23ede20a812f34
+size 76084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle.stl
new file mode 100644
index 0000000000000000000000000000000000000000..c2818a6266086849687d93c10833d19be237e36e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37c1779f71fb5504a5898048a36f9f4bfce0f5e7039ff1dce53808b95c229777
+size 9784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5f6470195ffb186581ef2451f97137e475ff0e9b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c6cfb105fb57001f377d986b8870243faa5a02b6d73fa81db32c590cf2f2c73
+size 125864
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..d511e7320d717cc6a271ad5f94f05f700fc996f1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_140_outer_knuckle_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f042edc44ede9773e7cad00f7d7354d0b7c1c6a6353fe897b2ca2e6ac71107fc
+size 78384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3ef56dcf17eb02fa9870e22f411d2543de2f49ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:111e37f13a664989dd54226f80f521b32ea0b71c975282a16696b14be7cc9249
+size 86384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..65e501d811be48c67556ef9637dd5f0b2f3bcd54
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f914125f62184d183c4fd30d04a7ea79222d622d584d5d39db0354ef19c982c3
+size 1864825
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6c39b86af6cd7cac704a703438ff5993e65c065d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_base_link_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74a62de75ae10cf77c60f2c49749b5d11f4c265f8624bbe7697a941fa86f6b3b
+size 1054984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling.stl
new file mode 100644
index 0000000000000000000000000000000000000000..8958e11500cfb541905c4a5179d71947f6895379
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ca9ffc28ed04193854b005358599dd9c3dc6fa92c8403e661fda94732d9ac25
+size 21184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..05e44f123c43c3bc469852c12f03078cc9f49a8c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e5f1f30e322107726a1f742120fc953bc2086e8f8eb80d1eb09b249dff63b5f
+size 273992
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..e2289a1e75b5ee9e21d22de667fdb5805011d08b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_140_gripper/robotiq_arg2f_coupling_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4281e83002a25c20dc07c68b8d77da30a13e9a8401f157f6848ed8287d7cce44
+size 160684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_adapter_plate.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_adapter_plate.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e4629051e91c163b78039dadebdc2bbb5a27d0dc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_adapter_plate.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a54704e6690f1c0e020003b83631e6e7aa3de74c7174c2df1edff5fc35d1713e
+size 15002
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_base.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_base.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9019e6a324a847da99faeb374757927dd3917a1a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_base.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5fa5ba3211cf22d21e34723579f536b0bf64d3ac58ad1293a960e19a746d402f
+size 755744
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_0_L.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_0_L.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e863ebb9fc4ef6de615786fc5aed112cad60be00
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_0_L.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c5bf8cfb5f36fa9122cccab8e1ef17089eb9816ed1706c3bde65ea29eca3e8f
+size 73327
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_0_R.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_0_R.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4a5de0f295eb1e67a2029d20d95c58cf10582928
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_0_R.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07ea7c7e2a38369cfa927f1121169cbd271f2c4c7a7fffca3b87420ee8b15af5
+size 74351
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_1_L.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_1_L.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6f648790f96039cda0cfe305d36bb537faaa6ba4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_1_L.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17a5979d5569da3141fa20d5a18856beefae656f2fd79f126dcdaf7e187de339
+size 112247
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_1_R.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_1_R.obj
new file mode 100644
index 0000000000000000000000000000000000000000..51febe17fefb6589028a5c99c2970a44964dad13
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_1_R.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e716745b5de753ba9594cca657e8e5610ed4f7485360c218fecf996dc6883be
+size 114232
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_2_L.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_2_L.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bc6f25d209554cdd70fb0b31591ff76dd952e49c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_2_L.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e77401e1335569a2b9fe55af188768829dba47f5dfc54a982925f39c4fed38ba
+size 78238
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_2_R.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_2_R.obj
new file mode 100644
index 0000000000000000000000000000000000000000..98f57bf4356465c0b8c334b2e8ba6636986e3459
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_2_R.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a0dac0a84b03aa3cdb0482a58d38ce98df1d904d8a919b26a35747273f3b997
+size 79705
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_3_L.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_3_L.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6bfbca7bb208c0994f3f3d0c27f8488e169dfe44
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_3_L.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:596cd27d8736cc273ead163e8c08149ef1c54bdaa3636744762e099f2be43a19
+size 214920
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_3_R.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_3_R.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b49116be495ca40aa5be1ad831f13676dc6b01dc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_85_gripper_joint_3_R.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:629c1639e3fa07d38e3536c1ec19227e3969aebe867ad55fa8e25fa5a027bc7b
+size 227020
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..9d4e490e7823efb0eb8b5cffcbc73dd1f8a32826
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8844d7d09e05423b6edb56b354eef561ad6cf4787d8b7f980232cd4346f46bf
+size 86384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_base_link_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_base_link_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..97f50b6eb33350f752b41c23b2738934deddf572
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_base_link_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d03d6e7395b0aa11ae7954b2f04a0a650448547d5d4f367238098ccd848b3eb5
+size 2523721
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.dae
new file mode 100644
index 0000000000000000000000000000000000000000..5c6319336ab72b67b82743f59ebc365155d03e15
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:807aff52c5a12ca5429cfb2eb19cc88dfaf0083bac6b69e24d49f6beb29aa2c8
+size 21799
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6679ab7c18f58da3537167796792d5c6b1410ebd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f967bf4a720ccdecc85579c82c16058bce41ae29505b6992a147b09915a78735
+size 21387
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..434b257a7f6ad3d511b4ab161a7b1013da2a5975
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2222dbc6ebe5579e718231c4c1766680cfe77297f4086e84c60711ca98571a7c
+size 18484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..daead518e98d02ed6a8ddc5334b976987e5f6df8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3172e8fe000fda75b9b29b48cd8cad4011d184a6df23eb2c3adb6c1f5a4eae93
+size 154078
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..d9af1e883faa8bc2ea0bfc02b4f647b9958186af
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fca72dbbdbdd9de02a96ac3cf693e4f985ce60593bf3fec750120280890bc044
+size 235
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6990e25783e645b539f8bbacf316ec2879b315e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58c5003a139496eb48c7a5220cb6c5ee9b72d4b7e99af425028f8d79ca90aaf9
+size 180046
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..241eaadfa5209f4a4e5bb5d2dd198b81862ac470
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:208f8225e5291f6476f7f4ca307d3786808c20a8acb0bb759bd274cc69ea4a37
+size 110484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.dae
new file mode 100644
index 0000000000000000000000000000000000000000..d15b19feb229310001158ca7a8c6f8231de78d83
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8e6bb0b3b5ef6ea1323aa698a16dc1ec1a878ed2c2bcd43547750a18e72d9b4
+size 18425
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a455963857784a604a6dcda14f8abe009ccf712a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c254381e04eee630a55c81acfc0466443425a8f752f5183737c25f048424ece
+size 16457
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl
new file mode 100644
index 0000000000000000000000000000000000000000..67ba4121dd3dcfe8a00587cb13903aa93f9d3d9b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:657af24c1dff19fe9ed7c027fe3aa67448b42ad67f2dd77a71fd8b6f04346ad4
+size 14884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..bb92c2765ddc036caa868ab892b1e5ca729dfc99
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19b12510eec739a1a4655ff633ce3f02ad3c2eda2cb8f451df900ddfa76c8f14
+size 117207
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9ddcdea5a0b1df886643b641dbbbecdb1cd9fb69
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b13a0335f4528e73f97578deecf80966f3d283104137a50542eebff9627348b
+size 136894
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5dc7bc2e860a8f146dee2dc76bc2e72bbc74d06c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c5d83c5732e50b5224e1393084d5123ce997f2d3468f97639f48d2257c5e9c1
+size 84884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.dae
new file mode 100644
index 0000000000000000000000000000000000000000..bea4c0c3eaf28a820b0997c1a7661d554c07bfb3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e45e7773260cc4b04bbae626af94bfdd49c38b3c69fc248cf0708e5fa65a39f
+size 21252
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3e791ccfcc4fa645b2e7fda13b0f778155fb4b68
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b404f95e3e5e372206de173dd0e70286a017f065385b0a2df8da89de15cdc492
+size 20584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2061c441dd1d4642783f356b8eae83c63aa18e7f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7eb5949cb98e382f500ef050537cafd5c8cebcbb753c88d3a6ef06ad42d8ab93
+size 17284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..77d2bf2b4ff3555378903a9005cac2c9b40006af
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32d842e51e0ddd25b3354af212131befa191c183cf4cdb61f82ecbe5c5276323
+size 120496
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c4bbc0bf4c1d7096472e26b0bba5732f12abc132
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d213fe23f7cb8fcef2b9c6556c9051edd92c2f03ccab49e81e77498aa591da7
+size 132641
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ad91ae4f66536e0497538cb1bd49ef12f103d1e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e77b6a9227acd546f433c5a29abf9d27848193ac410fd369c2b56c264d274d55
+size 89084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0c4814d9ba893ff5ab4828b35dc42d422fc02a35
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c7ce8ee42044149f9a956d55f33f06e6b75e5e35080ee4bf1e8cb381535aeb7
+size 25969
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e5f1275605f663f5839b4d5d3b2c145b58d1871f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f951bf4ef40fbe75e0b2bc0bd713cb8a72d75ea554dc9d62b2ba486d2251c04c
+size 25953
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl
new file mode 100644
index 0000000000000000000000000000000000000000..923521ad79ab288bda9a62d9692333d686ac082a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba362290082912a5cc1718d9406886a23aa3c15c6de850eca0ce81c09312bcfa
+size 21084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..4c2a9f4c01e8178bdd3d1d924486f4322b3c71c7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1cc5198c9f3b06979af07d2944034b32acbb6130df5396dc2c7253cb656f587
+size 69464
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..30ee68439d12479e0711b358207d3c513a76e9e6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba43ef55f4819ef60c79294037dca7f6f7ff1a7eb4d3f62fda728d1f9355dff2
+size 108831
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..52f9a1f822fead442bbe832b536eb6ad5368ae61
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d0949ab29a3def1b35b48b15576ec9adf5315d59a6277efe2ba45a48c3a393d
+size 67084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..29a457b7e183f3170e5258220c21f7533b437167
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f1d5b8403fed489fb40b389e3d9ec27882db324d77d20cc84cf8bb8901fee80
+size 4136
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..0b5446e63b6a984775f7c690499b0b0d74f958e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:063a977fb9800b206dc818234f4a2193ff19602366a388b4140442f53463389e
+size 684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..27875acad4f68c7c23267b5e2f401698d94437ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:021222c315c9f3318c058654b2407e142c93bdfa7ecf87b11fb67493cbc472ff
+size 110714
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3ef56dcf17eb02fa9870e22f411d2543de2f49ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_arg2f_base_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:111e37f13a664989dd54226f80f521b32ea0b71c975282a16696b14be7cc9249
+size 86384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_gripper_coupling_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_gripper_coupling_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..e2289a1e75b5ee9e21d22de667fdb5805011d08b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper/robotiq_gripper_coupling_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4281e83002a25c20dc07c68b8d77da30a13e9a8401f157f6848ed8287d7cce44
+size 160684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..fc34f505a5f2d462ca7ef782ae866c4de663f06f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1019b87c1dcff4a08a2fdca2dfd4893d60a5e2fc53512c6b5fe2e372b75c9aa3
+size 2307084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/base_coupling.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/base_coupling.stl
new file mode 100644
index 0000000000000000000000000000000000000000..eaa901b1171bcfa51d7f2c36f62cb8a3407223c2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/base_coupling.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a2517f9d6d78f89d9edb617fb93a279b4d52ac61c12d9cb743c701676eeb06d
+size 540884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/c-a01-85-open.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/c-a01-85-open.stl
new file mode 100644
index 0000000000000000000000000000000000000000..fda17bd8508a457ddde2de1981d00d33fda21152
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/c-a01-85-open.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13ad73ef491f6f28b9ed6b8fbb3d6fb45896110ea7d415cda1689fc8daa5d925
+size 283384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/coupler.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/coupler.stl
new file mode 100644
index 0000000000000000000000000000000000000000..c29e887c49050da3e8a4e17092726c6db0e20688
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/coupler.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54949f7355c35c976d854fb77272feb92d9201213e343a8852429556fc81d416
+size 641884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/driver.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/driver.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5da0f469460d9b9e7209896b0584e28d1fe0766f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/driver.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:baf8b4dde18ce59eeebc0928a289c69dccec9da81bb186e2838e2e304274e106
+size 438284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/follower.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/follower.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b1e46dd408ee4ed64ba512e6f9e612b903ea5461
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/follower.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28811d3651345dbb5f2020c67d3bd05f754b5e3e791e379c3c4d1d87418bb9c5
+size 572284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/pad.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/pad.stl
new file mode 100644
index 0000000000000000000000000000000000000000..be08ea411ae9d06d9ab566af28906c702f132176
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/pad.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0f4d31e867a5b3634c102669b76ac5e8c026ede5fc645b751a5eb3d4bb0be02
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300.stl
new file mode 100644
index 0000000000000000000000000000000000000000..200adcd0dbf6f42b61090bcf3bf6a39eca04ac97
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50be7b765c349abf102d81dc62d8816a39a6345d3e794dec505ec8d5a034b973
+size 5915384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..7715b80d70f9ffa0f36bf162004e60eef0eba0d6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1f5aa91c33e900fb10be4d8cb5658d669fe48fdc5dbf2bf61ca03002953b0cf
+size 1104584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_coupling.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_coupling.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b31f72d8479199dd4fd97bff7abe0456bb539e18
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_coupling.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c45d18f0a7aa4a52582bb0eeae061f76902803965dbd7712990627b200eb667d
+size 2006284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_top.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_top.stl
new file mode 100644
index 0000000000000000000000000000000000000000..4249b38b2e6cd342db0978605036720fc45aeccf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/robotiq_fts300_top.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e9ca76b0b09c6871e4c4ffb5031a85c19db62e5734ba41d93622322629fc189
+size 4111084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/spring_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/spring_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..1cd4d44f9e96b739b1a3037f7d5ddef9dbb46216
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/spring_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56e9f28ce90841d654ab6d953161aeb62142b1459c210335d5862e7fcb281aab
+size 656084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/tongue.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/tongue.stl
new file mode 100644
index 0000000000000000000000000000000000000000..0e502de31417ddd1ad66b8d18e97be1bc44211d0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_85_gripper_v4/tongue.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2bd5df8a2132542703d40d94be36a990bbc3eabe070283232871147081348d52
+size 383384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..7b8659dca8cfe90a377661a3a110ab4ca93832e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51a5dd4d3afb59a319724f9b51ef27616b079d49dabca2f160ae0cc06c9dd7ca
+size 10884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..85eaa49951dd376e78dd0d8577726eb8bf6c955d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e9d3a996e9508ac064e2984f0488eaf9147cc0c547f897ce70aa9c0aa623d27
+size 39300
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..da7a6a0920505546563fe55d9ca16f62e88b67a1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_0_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f555448a7907a16daa7c0462d79aafae97a81c009eb7238460646da690b70fc7
+size 25084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..18d32cfece91797a4b8645658943c2c36e96a3fd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50b05475bd700f1d6418cbc52bbadb952bb1ab876b850f9c36d2cd309d7b78c5
+size 11084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b03b4f5fa89144ed3724cadb11785997915e1c43
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b21997b387c22d130e785a086b7e15686c0d9118add614eadaad37cdd9f76ec
+size 73867
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..10bcba9b46ac9079d5e4263e0cab843db3802ee0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_1_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32cc3a859be1c5fd8ad7583a41bf619e2ea01d2bfbde94a17c5dd3217576627b
+size 46284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..14ac4e25c65fb153b8f3c2f0257758f940dd6672
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9f1dcebe3eac5d4ab3801d6e2722bbd9e99d9c3e3b2f741f8bd4efd997b605a
+size 10684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ce2e2a4e135352f45aaf6b9af1394d67d60c296d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68de4ff372a69164fae416134aefe6b32638d2d685c75dea6240ba37c4c85472
+size 69320
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..29ccf85d1e0aed6bd74517e06a641cd70db373b1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_2_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a49ffc1093232d6c2c690c3d59566f512c66198c23231bf47825be143a456f87
+size 43684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..d0e5b4383a53028043e53f882b2c2904f05deeaa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db8715e5676185ee507b60a466d3a31622192e7a250bc56286923d9fd1a435a4
+size 15184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d133dcbe7cf5a3f9679b6477e52d0621fa26bd68
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f1884f923a14621a8b45decc6524ba71e5c31022f9c9cb57352e8cc24682494
+size 59535
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..cf397c85d588b2582c8fbe3b91138c1fdd421bdb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/link_3_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea63028f37f3ff6acd587559e6c48c4510ee766c53df813c7c6cd5796beaf6fe
+size 35484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.stl
new file mode 100644
index 0000000000000000000000000000000000000000..8486c3b16b20d908f81b45eb40802575e70474ae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc17ceb8141d8a145af7abc86e89edad55a3a9b4e3a9c8bbbf7283adc8b759d6
+size 133084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..67462553270427a98f0dab21242693b96acebcc7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19b2e0b9dc20cacc505131f9d7b01bf0bcebdb12f3c3f453aa87b7c0fe8a3364
+size 1055346
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..eec508569dfa92159ed5241b4d96521f18730adf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/meshes/robotiq_s_gripper/palm_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11a87fed9f516a6d81473a8ca112e6b976d44383a3375ba6e8c84b7040455141
+size 534284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/null_gripper.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/null_gripper.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b6467381562ad0bf41548ea3d7d14aa370786fbf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/null_gripper.xml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..60a118587f2d80893cec7fe8f170717539e7a490
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8dfe6bf6554ea5057c47c1c7a506ea3089f6801300963231c150cb5c283864e
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate.obj
new file mode 100644
index 0000000000000000000000000000000000000000..629e685206ac79ce00cd97dfbabca077633b5ffd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4537b1b8c6d45fbc8c0efcdcce6cdaec14aa9c6a13b1c6eb52adcbf2be9a79a9
+size 2085066
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate/connector_plate.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate/connector_plate.obj
new file mode 100644
index 0000000000000000000000000000000000000000..827ede49ac7a3fe0b5c30ca5eb4532c2ef461503
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate/connector_plate.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22caf42a3e398e1a737efd1ba0c86788b194126c995c0e7449687709421b384e
+size 7162239
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate/connector_plate.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate/connector_plate.xml
new file mode 100644
index 0000000000000000000000000000000000000000..91b058d047d1f1962d17396c6592da9df457db28
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/connector_plate/connector_plate.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..97c37a579113abad3a3c13ceaa62d68730bf8b47
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a43c875708217c6c96f7825ee317b89d7ce2913af83cab277118deae7f77a8f
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d9b61704a1cc0341f7f4238860beeba921a0bbeb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd36692c2590a3f2d6d128a0dffbf7d4f6ac316d0bfd2d9c2f66a2e31e3edd1d
+size 862969
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base/electric_gripper_base_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base/electric_gripper_base_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bc5dd9de415768cd52699f92187da3bff3128864
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base/electric_gripper_base_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:234baff59359fca1ebd3dea25acb8ac80d28e6a576b322dd86d0c9e6cd8ca4f0
+size 358197
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base/electric_gripper_base_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base/electric_gripper_base_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a5acc7bc46fcf41a91dc706469f5285df8a9eba7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/obj_meshes/rethink_gripper/electric_gripper_base/electric_gripper_base_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:995fcb7934edc215173f403f20870799c9772fd1f314e749d11c46d68318e330
+size 2512399
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/panda_gripper.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/panda_gripper.xml
new file mode 100644
index 0000000000000000000000000000000000000000..fb9463484727e1df64cc3db2be2ddb4a8341a22a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/panda_gripper.xml
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/rethink_gripper.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/rethink_gripper.xml
new file mode 100644
index 0000000000000000000000000000000000000000..072542a76e7701285956397bde7f9066a7950466
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/rethink_gripper.xml
@@ -0,0 +1,73 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_140.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_140.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e6b7d79bb835c4fe1ae1a5fe2a9c008d4e6c3d2c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_140.xml
@@ -0,0 +1,119 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85.xml
new file mode 100644
index 0000000000000000000000000000000000000000..22ec92d5546c1e7a747a3fe177ea1d5330c2ce65
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85.xml
@@ -0,0 +1,185 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85_real_kinova.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85_real_kinova.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d4a24161ea0f7311a935fd4851b2e22d2fbef881
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85_real_kinova.xml
@@ -0,0 +1,186 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85_v4.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85_v4.xml
new file mode 100644
index 0000000000000000000000000000000000000000..13e09b7f53b2cc320283229e501edbd89968ae34
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_85_v4.xml
@@ -0,0 +1,165 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_s.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_s.xml
new file mode 100644
index 0000000000000000000000000000000000000000..26e5c749114241d9343053fdbaa41c6d09732d9e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/robotiq_gripper_s.xml
@@ -0,0 +1,182 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/wiping_gripper.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/wiping_gripper.xml
new file mode 100644
index 0000000000000000000000000000000000000000..829fff7ec3733e777e528ebc70f3885c4702f7f7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/grippers/wiping_gripper.xml
@@ -0,0 +1,73 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/light_maps/photo_studio_01_2k.hdr b/phantom/submodules/phantom-robosuite/robosuite/models/assets/light_maps/photo_studio_01_2k.hdr
new file mode 100644
index 0000000000000000000000000000000000000000..b298836f7ff552d84055d8813bf2255ba3fe3148
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/light_maps/photo_studio_01_2k.hdr
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6178cf04ea2ac9390b8794d3088a04c0254905f335f5f30d2b582c57c40f387
+size 6375901
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..04d42877603638d350f9fdb749596bcdcffd9d45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e23b5279777ea3bcbaeb3a0c748f95a51d1dc3bcce1db8c035a9361512e73a66
+size 124
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4f1d07b322563a13b0a87fe0113ff01fa9435cdb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b57c84fc95e3497c4554bac47fbdac9b50622083e36010d99ed24604b749d81c
+size 16065
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ec574cfed5b11239176ef327c720f4a8527e6f99
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_collision.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73296b38d0d267f5d2aa0f8626432807e7e7fc3b6aa50263da9d31620d5cba1d
+size 10284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..31f11409b2145ac7455d93c8d3aa918c51aa7f21
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3658352da1bf92339c8d5da2a9c83c8eb952779ae351ae3a7c3d675923a1078
+size 139
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..f357a16537a5a203c83bbad4d373d8a46f1f6127
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fd6a4a0ad76ab9dabad8de389d043e1d1cf856b0d94d8e4bd5a4e476fe774c3
+size 1128212
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..e6b86c57258affa2767b8d6846bce74ab5babbef
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_minimal_mount/pedestal_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73266304fb9e039e2d77051465f5a7697d351326c09b4609b1d8dede6ef5fa6d
+size 501234
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..3d1b182cfdd00212645903b3980e1597d0cbd76e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9a74be4ae114c9acafccf68e1e49d8fd815ec030012ce1b60bdcf9b30db49f5
+size 2734652
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..7486ceea55881c4a527e51837c38df066a77c9db
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad7c0b8f65e2212a0c70ba09e79ff4729cc288a963bc52d69c0bc0304a683aaa
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..abb25d449d73b66c81efdf7c3484d4e95a571762
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5db81bfd184f4a20857a238d878cdb73428986767ec7aec5bd629f75ee71c75
+size 4009925
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.stl
new file mode 100644
index 0000000000000000000000000000000000000000..78cd09c62f6cdee80a1698c3077c7dd6d5b543bf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/meshes/rethink_mount/pedestal.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f96ff12bfb347c77580e4392556b15e47ef7228711c45b0f1193b100af22866a
+size 1647684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/null_mount.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/null_mount.xml
new file mode 100644
index 0000000000000000000000000000000000000000..7ef2cc3631ccc42d2f1134d6c7900d9ed86aa0ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/null_mount.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/phantom_mount.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/phantom_mount.xml
new file mode 100644
index 0000000000000000000000000000000000000000..f4b25f61dcc72d08ed06a1670eaf2234f574f2ef
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/phantom_mount.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/rethink_minimal_mount.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/rethink_minimal_mount.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b3c31886b5f1a9c98e90ff30efc21f6b823e72ca
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/rethink_minimal_mount.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/rethink_mount.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/rethink_mount.xml
new file mode 100644
index 0000000000000000000000000000000000000000..7fed3a86436dd94cf60601c7af2fa0e3429b5181
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/mounts/rethink_mount.xml
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bottle.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bottle.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5b83fb2f0d6f8495997cf8fea1b92e8c45173583
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bottle.xml
@@ -0,0 +1,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bread-visual.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bread-visual.xml
new file mode 100644
index 0000000000000000000000000000000000000000..ff20060006c2250896d01e71c10a07db88053245
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bread-visual.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bread.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bread.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a5796c20d2649482a79d17642b00a82512c703bd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/bread.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/can-visual.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/can-visual.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5c9d905ae89d55866b3c65be69b60481f165155e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/can-visual.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/can.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/can.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d9c222b379b5dd01d6b55649b29039b79efb8c76
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/can.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/cereal-visual.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/cereal-visual.xml
new file mode 100644
index 0000000000000000000000000000000000000000..eb399031e23eaa73f578e855daa7a1c9fc42c87c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/cereal-visual.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/cereal.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/cereal.xml
new file mode 100644
index 0000000000000000000000000000000000000000..708de0f5f374afe0e6be7e955043145bc8f5f16c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/cereal.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/door.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/door.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5ef1c5a685f747ff7e9205965b19f924bcac9336
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/door.xml
@@ -0,0 +1,40 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/door_lock.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/door_lock.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5337073f3c14f48f15d19baaa345ad38267afb82
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/door_lock.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/lemon.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/lemon.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6b2c6a2f71daede6a1a2384052f7c3f34319e8ba
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/lemon.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.msh
new file mode 100644
index 0000000000000000000000000000000000000000..54a9d787168fbbda5d14ea4c8144ae5dc153e469
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16203cc4564fe8dc2c1e5a6bdbd20ff87561f016f3b11a2fb6445e144cf80fa6
+size 14272
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..dbc0ceb4c124304b234285c1dba173ed9befb5fb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e427f17d7d05eb45adaf1e2bd7fd249bc3fe56437ba8e348cbdf147cd384b996
+size 238
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fb0198196513ba778408d536fcd511756261338b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92353bf012eaea33c1e0d78d391fa61e083d53c57325e68d3410b31649c43da9
+size 8999
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.stl
new file mode 100644
index 0000000000000000000000000000000000000000..cd2419eb4fd7fcbb074c214c6c0dfdf3168e4851
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bottle.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2b6ca097b2d0d43b255fe3d605c8628117459b1d7369220ddf62ced5e84d962
+size 6684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.msh
new file mode 100644
index 0000000000000000000000000000000000000000..93a48ecfb152f84765df8d80f6e9914b54b271fb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:befe24f2e5876b3bf7df07c92622c52691ec818d73c9594cb68b6fc43a1cb8d3
+size 11248
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b0b6cb8dc80f989a843c100aaf94cce73b4fbe3c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d776116483420a01ef53fd8f941583255a56467f521768027bfcd0123b7be527
+size 265
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.obj
new file mode 100644
index 0000000000000000000000000000000000000000..f8424f309bdf8f6bf7c6e929e0f013cdab2937a7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb11db62e9da167ebf91ff4bed4d07a2f8d8210d5333df1bf01e18217a37cc1a
+size 8026
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.stl
new file mode 100644
index 0000000000000000000000000000000000000000..1215c89ce1a8337fcda6d45febbe9e6a7c885237
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/bread.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c6646bcf9ec086d972d4d89141e8bb8081f14170306376bae59d3be59a6ace2
+size 5284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.msh
new file mode 100644
index 0000000000000000000000000000000000000000..e17dffc422fd0ba18764f53677c8ce0297fc65a0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a0e736201604130d47abe6f93e9173b11b1d49df6e2f7093537a435c061cbb8
+size 103264
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..aedbd504433cc686ed638d5d13edfaf5a32b39c4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:292e64f3385432e11d55dae641f29b6855e4fe4e2d1d48baf7e7ea218a5e9d85
+size 264
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2dc699c438431007ee8815cb6b3bcb32db225d69
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1904183501b49346131ed46bd1ee6fbf2c14f394c233ef9df048758c9a4b5118
+size 77216
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b848dab76d94d2e184388ba8adae65470b3b0893
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/can.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb459d8f0dcfc855aadc91815462d69ba58b4f5bcc289631043f17eccbf937d3
+size 47884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.msh
new file mode 100644
index 0000000000000000000000000000000000000000..671c657da225036af7c2931f4ec5a5570ad63b29
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:baee1abe4b11299a3cacab3db566273953f6444241c92a32ba3c8dcbf7fe8d5d
+size 12328
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..a5dc698ebc28f604917a9624364da4da959a1adb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:485789398b0ef6b1c147956be55419ea06396686e69a4fb244bdbcdededda341
+size 266
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e943e328881aa8e202549f442c84af21277cbdfb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8377fc93fca888314fe6a47bfe392fc8eb6db8f5c0cb796ffce8f9e47f589bff
+size 11709
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.stl
new file mode 100644
index 0000000000000000000000000000000000000000..99560a13d2a6df746a78c60eeb0a303c0258041c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cereal.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd415818056831cb6fa4b40112280f463de8624b06e23d188a7e0f83ae5be1df
+size 5784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cube.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cube.obj
new file mode 100644
index 0000000000000000000000000000000000000000..10ba1dc72aab49609f02593bf2fb09d61e205b72
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cube.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56a9cb677cb4f267f49c879af58db1d24d0160688781c90498bdd6a879e87dd4
+size 788
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cylinder.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cylinder.msh
new file mode 100644
index 0000000000000000000000000000000000000000..a32038a621d6aa0af121f3f3db35ecd6bffae47f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cylinder.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41061b9ceee866bcc115470d0e3fca91f013f2b1b93b5aa03f8357a9025c37a3
+size 27232
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cylinder.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cylinder.obj
new file mode 100644
index 0000000000000000000000000000000000000000..110f6e974c245f291a8f6b32e396e69769ec71a0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/cylinder.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d8c3cbd3c7e45f6fbb324c2c890a0fce966196813e2f4567356c814c053a56e
+size 23189
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.msh
new file mode 100644
index 0000000000000000000000000000000000000000..f5843d937342345321d686f599fb44863f9cf26c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84d98eb8a8e406de6745280af4461b13b9e4e1475451aa27674de3ae27c5a486
+size 14272
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..dbc0ceb4c124304b234285c1dba173ed9befb5fb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e427f17d7d05eb45adaf1e2bd7fd249bc3fe56437ba8e348cbdf147cd384b996
+size 238
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.obj
new file mode 100644
index 0000000000000000000000000000000000000000..24bd3cb5821deceb45c618424ce2c507d0b119e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a14962cb87de19cfa50bfc8d0fff4aab01d82f0bb7460b0894770cc635ff635d
+size 9695
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.stl
new file mode 100644
index 0000000000000000000000000000000000000000..39d0590d0fa05b180aea24a1ae6dc487ade165fe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/handles.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5564bb9bfc069a3834f87e6761c5e5797856fe8b932928aa19efebb1568625a
+size 17484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.msh
new file mode 100644
index 0000000000000000000000000000000000000000..6dee42b15420cc45068b92ac3d696e2f5750e36a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a2dfa73cb1e6650180bdb67fe253c565c409cdc6d89f70937017afe5555780b
+size 57040
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..4909d82692b41142528425095c2c1e5eaa8b71e2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4457f44c4aacbe345e57c7e73ed87656f06b3dfba225e32ed427d2b79c79d34
+size 265
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b139f4392c19f6b1827ee41a6afcc16bc0db037a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97cbc7113211f9525427cfd9087b891fb32eef55a5f691b1227c4f498e5bd99e
+size 51035
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.stl
new file mode 100644
index 0000000000000000000000000000000000000000..59f92681d6f5d877960ef3ad45963cd1c2500c8a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/lemon.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcaa216283267fff9621f524d609b59aad279971dbb52cb263680f03fc1f79db
+size 26484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.msh b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.msh
new file mode 100644
index 0000000000000000000000000000000000000000..ea137bf28fb3148e04d4514b59fdaebbdc799f80
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.msh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92dbdf689a2fc2a46ddf3d5390cac3bde7fba9bfec642e6693acf205a201a40d
+size 26152
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..054a76bedc1ca9f0da9ff303e942b3f71d1602fe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3a9c63bea5ebf37805dcd187627773757d056bf82fed26e4c0298c99c3b21ef
+size 267
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fb85f09df8e0fc74b0a096dab13ee4f4b6042c26
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15f99744f9c5a73e50750f1f3e44208d084bff5686f96281349ed73ab38666c8
+size 20373
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.stl
new file mode 100644
index 0000000000000000000000000000000000000000..05a17aadc8a973b2c49f5fe1c344ec76f21cdfb2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/milk.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7880b8da658d7612990cbafa83fb8d358005319a5cab7be6db6993b4b830b6b
+size 12184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/sphere8.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/sphere8.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3320445e5d26abfcd6cae2e81c426e08333722e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/meshes/sphere8.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:114f2371259557ac546250c63690045bf9fb9dd89171c56124a2df9374d120fb
+size 46933
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/milk-visual.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/milk-visual.xml
new file mode 100644
index 0000000000000000000000000000000000000000..0b92a03fd62892b9341459c04f3cbf79b9a9bfdf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/milk-visual.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/milk.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/milk.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c6a2404f1378582c1fc33920bf1da1893c54aadc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/milk.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/plate-with-hole.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/plate-with-hole.xml
new file mode 100644
index 0000000000000000000000000000000000000000..29ff23acd51beaf9f3d2698a14b79fb1d8fc182d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/plate-with-hole.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/round-nut.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/round-nut.xml
new file mode 100644
index 0000000000000000000000000000000000000000..894547e5d75613f6d05643b911be1cc39d66c3c3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/round-nut.xml
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/square-nut.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/square-nut.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b33db2e8055d479ee1b3c15ee1534e0fdebbdc0f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/objects/square-nut.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f2c28be24cbe8df8292234616147770959d01cc4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b76defe6cb672ae9738a457b76312452969d2b73b8495e34269008c051d715ac
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e5933822637b38d096fb3a90d72fbb893f982a66
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cd395c754a8d50c124f00ae9424f6214e697519b3c5c0290be993fa33d8996e
+size 442742
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..de85f896b6886b97172f1a323887787c4fde93da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fc0a1a1bee8949c2247ccf837883af5e31516bc54dede2d962243691c0c8c68
+size 260834
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..5db192ed8f17c6f826469e5df1b47e3ce0486ad0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb7fcba8800dde485fe84b064394506feeddfa589414aed49da8494c93a3b247
+size 423
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..359ae12d5006880c825906efad350bc3200ccc34
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f6aa86bf7b0a59941ee85f0de0c94a0779bb60c347d24b2f2da688e037af2f4
+size 305084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..949a5e84ba7069b847bd31e392bed2e907baf633
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/head/H1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96cbc1f7b8c5a7a89927d0a91cf7d3993df6ce43f31e96c6c7baeae0e34d072a
+size 174384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..1645322fdf43bba38c31773faf6c87816409bc51
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b82e27896cf4c64d445da77b52b6f1d643fc76fbddc1a54b67572e83ae852c20
+size 423
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9fa04217c4ceb2639791bda0c98344dceaf8de33
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bba3da6ef201294504170bf897fd8f2cb452d98fde01f306fb6bc45381f5a6b8
+size 1115988
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..9e4e21da8b4a51d3765062c352208fe22c579b8c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_elbow/E1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68b033ec7e612536b07f6d095e6e43f74fa368e42f5db91555ab38ec504e8fd0
+size 196634
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c1de083bd694a8d0ec52d991ebdf634e983ae9c3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e29f69975eabbf015f0fa798c5cb2dd362391cb8e90e25b2801515f42e2c1994
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4e6e1acebb5fa97ce8aad5d59e167542bb59ae54
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f1bfab9203dc80a530129d03b90a9be6edeb9704d20fbbad9ccb63ceeb018c9
+size 1281812
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..9004f5a45481637bc66a90bad5d1ed16e3c2a966
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_forearm/W1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f5997c917acb65ba1c628bc0a60af101ccd1b2dcbf30bfd23f27ef3683c5ad9
+size 180184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..eee673b13547faba785a78d1844c1b51dfc23d7c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da38617d1e9dbf7521070fce58c7c9539d8d202a9c7bda713e81414271e76e37
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..525b73c94eb2cf01c2054cde21fcbbfec9ead465
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06ca546f11a057e3f90ca808c88b13531f1b0c85d5871394ce876f97c0017cce
+size 585520
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6f85f47d825faa5f2a3eeae2fb4f91d4255a0677
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/lower_shoulder/S1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:679ae0707bfc3666943ffa94103f8b4134d924d0c86ff3e2f165f7f328988c2a
+size 168234
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..da780dd9687471c53d98d546a56496387f0626bd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af6e4618bf5b2698573adb222d79ec3e3a2eb58c31f9998a5537029b4eec5fc7
+size 1167
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2a9dd77426259ad43fdc65c9ecd09e5a8513d798
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:657b6843909b4ecec9852949e46c3bcb50382e7354353ad4afea064f0bee6073
+size 7674162
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..c9d0b7bfda82e66971b63b266f54b5a1ff479019
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9bd43749dca13a758d7a8aa1fc3dcccb20a50c3a9bd03089348265ee9a020c7
+size 464684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..ecf5425e98a66698505ea5616f2d11ba6c81a903
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bafb2488a40cdcd320933854e63846d52069484d2cf9ccb863160ff3b28a149a
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e9ae0148cf18dac3fe82c8dac5a85c198f446c6d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03ae32a5f1087f4165d442574b2c71d56eaa9ee8cbe94c0c484db2938da5e229
+size 958444
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.stl
new file mode 100644
index 0000000000000000000000000000000000000000..369ffe19ab92bce8d47f31ffaddfff0fffd2d485
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/torso/base_link_collision.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2980b03a0e699f0d85dd4ce3074758ee7067de4b81409b5fa1db4e3078c8b58e
+size 458034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..2fde2abe0c641872c06d0f031560801b37efd9a6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e38a7464ee3b85817f62dc556e77ab2f153149f16ff57759015335abd9514f1
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b0b0ad0f891a44afcc739e0faabbcc849437dd00
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4033a198ba715cfa8cb1d25d1d419bd76438abde20371e60b55ca52afae815fb
+size 1249570
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..cf1c2ed223be70895fbff5408d136be544b6d9b2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_elbow/E0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2badd3a918969773800ed5e0f378676823c446d511383f957d1379fd75649b3c
+size 214984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..fed78c3c1e365d5f0ea8fee6140543a06aba6173
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25c3f09349ddee41ba636c9a31fe17672681e580ed7af9c5a65a3d0b1225a9ae
+size 807
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..968c9c04847e80e44ea1f8bf10a751f94439def4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d3130c527590a6c53c35a9110dac29bb979384e87aeda6433d785d62d3828c6
+size 2802095
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..299c3a47dbdcaf8c481858a41b6b6010887e6d0e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_forearm/W0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97d38ddc30cd28ed713977f69f991889cfd16574470265a8d575cbaa0c96d295
+size 344684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..2560aa9ccedfe243bd6a73f02177291b163514c0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd4a453fc16c84b40ccf82e79c23958d93c552fc3ffcf2a8fccbdbe80c4a6d51
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3b2722136933baaf50c7ebeb18cf9a76b279d495
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ff27deb7fcc317a7552a7f58de317e946e809dd3fe272a8d7699531d54e9ddb
+size 3708139
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..7da1a767d4fa90e51194a49ad94749b21bca0bd2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/upper_shoulder/S0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a52d18783743e956c62dc5693381748b40a0b3c5f19f63ce59495bf13e74fc8
+size 402034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..97ad78af8b0a865f0229a7b73226da093d869935
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a92de9e7504631dfa839a43b739d720cbb0df1208eaa30d0650ab53b3f58236d
+size 617
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..dde0fee77f5861ce13374492f76d6f3ce3ac4054
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c18c29a228f1db62fed9742d0e3e009757f1ce98c5ff63a05084c1242730f24c
+size 1168229
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..f76da64469251765e12c8b07cc7e20172fc341d0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/meshes/wrist/W2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbaeb449896a49d5bdee6d40bd6eaf26f43a3f4b156b2328ad35f4aba00b73b6
+size 148434
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/no_texture_robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/no_texture_robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..eac7d8e648b4a1e3d634bd4342d7175d6f642a4a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/no_texture_robot.xml
@@ -0,0 +1,194 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f2c28be24cbe8df8292234616147770959d01cc4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b76defe6cb672ae9738a457b76312452969d2b73b8495e34269008c051d715ac
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e5933822637b38d096fb3a90d72fbb893f982a66
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cd395c754a8d50c124f00ae9424f6214e697519b3c5c0290be993fa33d8996e
+size 442742
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0/H0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0/H0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a1ae6aa3f04b0de267b37a7462edaaaf4ad5b891
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H0/H0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75393f5e86ed251401de0ee80b310b9f664fd5fbeb7568624d7ae5c178c55cb4
+size 550064
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..5db192ed8f17c6f826469e5df1b47e3ce0486ad0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb7fcba8800dde485fe84b064394506feeddfa589414aed49da8494c93a3b247
+size 423
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..359ae12d5006880c825906efad350bc3200ccc34
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f6aa86bf7b0a59941ee85f0de0c94a0779bb60c347d24b2f2da688e037af2f4
+size 305084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1/H1_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1/H1_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6bdb9d49fdc6dacc3bcf0be01f0ce90ea6580974
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1/H1_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:698c713c81b743493c8c7d6148528b3e3c0745e605ef422a2f577cb2592328ff
+size 408153
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1/H1_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1/H1_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b6a79d2f56397af9d42be8e6a77b069c980e3c74
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/head/H1/H1_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05f9a776f041fc9ff7459f0014770922804e883d1fb6646b6bfc50a83b37fe1e
+size 483
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..1645322fdf43bba38c31773faf6c87816409bc51
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b82e27896cf4c64d445da77b52b6f1d643fc76fbddc1a54b67572e83ae852c20
+size 423
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9fa04217c4ceb2639791bda0c98344dceaf8de33
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bba3da6ef201294504170bf897fd8f2cb452d98fde01f306fb6bc45381f5a6b8
+size 1115988
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1/E1_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1/E1_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..eb7d94622361dabdd7b042ce3eb2be3f1f9b12a8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1/E1_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6286acbf989bf2dab5c979e4f6429cfa92edc5dd4f45e4d6c4f7e3e3c47973e4
+size 1082756
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1/E1_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1/E1_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4becca61cbd400972034d95d700646ad02f59352
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_elbow/E1/E1_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87b4090e97be61c3d97ea5201b35106f974312386bdd9fbb15ee5c2ea687b326
+size 364501
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c1de083bd694a8d0ec52d991ebdf634e983ae9c3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e29f69975eabbf015f0fa798c5cb2dd362391cb8e90e25b2801515f42e2c1994
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4e6e1acebb5fa97ce8aad5d59e167542bb59ae54
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f1bfab9203dc80a530129d03b90a9be6edeb9704d20fbbad9ccb63ceeb018c9
+size 1281812
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1/W1_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1/W1_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4ede251c7d2166ccddfc16201b3f2d41aae0f268
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1/W1_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a061afc4df56e4cc9ab48009f1edb6660f4ee0b4ade7242af7250d33a599c2f
+size 1289984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1/W1_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1/W1_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..53869fababf0d9c4f4e7a76b6a97ac8951fe11ec
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_forearm/W1/W1_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c79f09f34d9e0de50389fe3d9ccbed5fe7295cdbdae89d437446ae34ca8a9c7
+size 369676
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..eee673b13547faba785a78d1844c1b51dfc23d7c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da38617d1e9dbf7521070fce58c7c9539d8d202a9c7bda713e81414271e76e37
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..525b73c94eb2cf01c2054cde21fcbbfec9ead465
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06ca546f11a057e3f90ca808c88b13531f1b0c85d5871394ce876f97c0017cce
+size 585520
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1/S1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1/S1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0fd1fb0b2e63fce46d9660116450df62756499e7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/lower_shoulder/S1/S1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ea783562471b62e2266f8723d26f5c5075ad1f1de16baaa51965a8ec51b20fb
+size 749512
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..da780dd9687471c53d98d546a56496387f0626bd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af6e4618bf5b2698573adb222d79ec3e3a2eb58c31f9998a5537029b4eec5fc7
+size 1167
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2a9dd77426259ad43fdc65c9ecd09e5a8513d798
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:657b6843909b4ecec9852949e46c3bcb50382e7354353ad4afea064f0bee6073
+size 7674162
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0b480ed159015071cdfff98a372288b25b0e809e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4d0e8d08a997c215656b39eaeac49810bed85cf6a777c07e87003861454577a
+size 436791
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7e99d8300c1979629f09d3b156b307c0856765b5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29e572575d0a51fd38394eed90b90e444ad45dab5d2211160adcc82861bb4419
+size 917951
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a638ed4decaeb326248fe14d1762bb59da97a06b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cfab1eefa32e40cb3a15507745e20f2f822871c8006abfca4eaffad35b9cd784
+size 586928
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bcd04f31bf6e318b7910f610284e8bbf89cb42d2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8d10810985fe13a711b67e6823507a0e1b9def877830f1595674c6dac4916cb
+size 643868
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4bbcb96d2478383c9d69379c869996859e2c7f13
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba6bb2b561734838120c24670341973acc4280de4b3e51cf6174526716fe761e
+size 7378939
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8516199933334d41091909b9e7df9a7d98eaf217
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link/base_link_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b907f1684b602fb4b5fcccf105c65e6e9cfd19d70ee007d6d6cb97e5b5b99e7
+size 76285
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..ecf5425e98a66698505ea5616f2d11ba6c81a903
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bafb2488a40cdcd320933854e63846d52069484d2cf9ccb863160ff3b28a149a
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e9ae0148cf18dac3fe82c8dac5a85c198f446c6d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03ae32a5f1087f4165d442574b2c71d56eaa9ee8cbe94c0c484db2938da5e229
+size 958444
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision/base_link_collision.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision/base_link_collision.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1ad2f9f3f2eb950a9c0153625da4d1fb0834ac09
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/torso/base_link_collision/base_link_collision.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4823c7870c2c4feebbd7cfe2363097dca57093f4477fff2b941834eeefd815c
+size 2367315
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..2fde2abe0c641872c06d0f031560801b37efd9a6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e38a7464ee3b85817f62dc556e77ab2f153149f16ff57759015335abd9514f1
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b0b0ad0f891a44afcc739e0faabbcc849437dd00
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4033a198ba715cfa8cb1d25d1d419bd76438abde20371e60b55ca52afae815fb
+size 1249570
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0/E0_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0/E0_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..608ad6da6bcc1a01545371dc7d00dc05bcb73a87
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0/E0_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00e7968c4995248e49e30d97c2cd216ac453fe37346fefb4bf71899481187014
+size 1227887
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0/E0_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0/E0_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4c27157e19d075f6340fb7c12c9e4e7274382a8b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_elbow/E0/E0_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffa8d3fcf0b39d04fbfdebfcc712e857984747bf74ed4b6b16bf4b1852b85667
+size 389970
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..fed78c3c1e365d5f0ea8fee6140543a06aba6173
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25c3f09349ddee41ba636c9a31fe17672681e580ed7af9c5a65a3d0b1225a9ae
+size 807
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..968c9c04847e80e44ea1f8bf10a751f94439def4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d3130c527590a6c53c35a9110dac29bb979384e87aeda6433d785d62d3828c6
+size 2802095
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..629aad022fca2bca9532342a515007237f845fb0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:caf539d497011da5f8e4d9c17acfd9f30a9a16c825b8db0553ed19c505a6a271
+size 202838
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1a610fd515d1f7f3cd691e36d78dcf2dd00594c2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0ede1c72daf2a6eb428c98b628d5ed59e57694f967a38c88d434fd65e751bf3
+size 2859270
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5e40a54c19c243dca68ccf19d28d89145f7b4776
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:331ba0c22f54fceed1720cd759885464c8e2b4e75d2861d03e5d1d6c9a5a742d
+size 189386
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3318373728790dfe7a495b2ff8dec4425e2d818b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_forearm/W0/W0_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9fce43a04f3c5f571d649ebf74fd3edd2c3ce7b208794766413e848a83806efc
+size 366039
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..2560aa9ccedfe243bd6a73f02177291b163514c0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd4a453fc16c84b40ccf82e79c23958d93c552fc3ffcf2a8fccbdbe80c4a6d51
+size 427
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3b2722136933baaf50c7ebeb18cf9a76b279d495
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ff27deb7fcc317a7552a7f58de317e946e809dd3fe272a8d7699531d54e9ddb
+size 3708139
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0/S0_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0/S0_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4caf5ff4dbebcb9b637dbcb80c2fe0739b38636d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0/S0_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a5a1599cf4d83e1f14678dd6ad4f3791c6bb0ef65c4b5ec7033fecc6b7932c
+size 4816582
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0/S0_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0/S0_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c28ac4ce19fdacc1c80e80718ee36e5ad0471233
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/upper_shoulder/S0/S0_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95c1432b2f06895dc8adbcbca34479a6660c4de15146d6dbad5a50bffbf1e66d
+size 14317
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..97ad78af8b0a865f0229a7b73226da093d869935
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a92de9e7504631dfa839a43b739d720cbb0df1208eaa30d0650ab53b3f58236d
+size 617
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..dde0fee77f5861ce13374492f76d6f3ce3ac4054
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c18c29a228f1db62fed9742d0e3e009757f1ce98c5ff63a05084c1242730f24c
+size 1168229
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c95816df15a9866e10e7a7d75c3273a4ca031e6d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4efebb0c202790f3953bd2a9d0198d90a1e8b06f7869ef3c11a39bee0bff38a
+size 343268
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9bb90df69d417d8398fe5cf480fa78fb649c7ff2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfb43f7ff99def357f99c52cecbb710bd60b30a4b57d2acf96ef41cb6ce5dc4e
+size 3349619
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..21d84d2987ae997ee59a6a9b493df554542b3a73
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/obj_meshes/wrist/W2/W2_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1db9581a8d4fb2745ff8e87266c146872e5315bf49bfd00a395a377d589f6d3d
+size 138204
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..bd60be1feb5622a7d39bb889746c68540215be4a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/baxter/robot.xml
@@ -0,0 +1,296 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ffe82631ec082acafd07d1751071059059006892
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4568669d3a1b144c8f05c850ab7eaca71dbd3ad55ff93ea5a05b861faeb5f7f
+size 191420
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bf36feb6f3c5cbfb95930a976b38d3bd6cc28bcc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c34032c80481e163cba7f18a105bd6379fbb6a829d362d834500e3486e41b710
+size 464934
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..15751124b37a00e80c3660db6f4c670c88525860
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7e40268ac1fb6da2804ef95f1e7489d2ee4c60750d8c3e2d71c75257531e594
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b695f7ea29d5cac3b5594733c14f0a7e473cc39b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ea0ded1fd32f422fa4d49ba8339ca519377e9b5b52a7787431a2ccd91d411ea
+size 3879140
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..065fdd69d999bcff6fee4b2b1058d4528f16ea7b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_0_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffb1fe84e189a71d44e181b9c1b536223065520ee584530b5bed387662e87f65
+size 1782284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4f217d591b489ba0ae2a0a27d1a5736f59d5743f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1b0f52b84415bfd0ed3056e6e8fd73f3549b28116fe0cf9cd86b0bbbd9fcb26
+size 175775
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..e13ceb7e1ddd80199d26bbe40d1e974cf791329e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:abc2f56d710d0c235093772f67bf66cd70ee1dbaabb215eeeadf9ebe1af8c91e
+size 193684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b9544679b3e5cb9fe731477cd85bb465fac89f52
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff59907b9e16f843734c07638cd6a5b7f4e6da3033e37252b2ff45f3404581a7
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bca5c9715c4748f9fa8c406d2cb041a97b87799a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:baba099476580b0202ab43c194b204669b775876c91261a5f847e9e131ab125a
+size 986388
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..4d98a0934ee86146832ba9253940cc537a0be81d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_1_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:539a59f6200bdba9bf87f98b49c935085173145331438d388194febbba743895
+size 465984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5a152d229eab94565585620294cf867adbed6c72
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf01b42d94b646acfc520ff01ff62e393215a161a554d7345077bc88c90ae744
+size 89639
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..4ddb0184b5d252c13b17ae9d8bcf7678a6b39150
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6e1626424e9172161c4830eca09d9bd23f7c7571eab8e84e77fd28913834f40
+size 422184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..142f765e9c8a9fbe115f2e4d492187897521ab8d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7331b73a2eabfd947f83a52923ffe6e02de22631b8b7406fc92d6ff1fcc656f
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..25991748c4227e12b6135c5bd707ac5f9d888905
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34ad0041f3ccc351e0a947c115fb1100510f0ab76e50fea2d0f8435f5b5ed68e
+size 2988723
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b1dfd0d14650b341801dfbecf5a217227f348423
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_2_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0511257c13eb8eedbf1f5e02aa3c6132a249554c4c512295400a8374812051b
+size 1327984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..79859b7b22e9c0850bc3089831704fafafd667b9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e57f6c82d68a0c7ad336995ce37f5f0b6463d5f7b69ea9f3ab4bea4c4cfbdcc
+size 119675
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..e593b3a8a6c6b72d8b85712be4c17eb7cbd6e7ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b8f8382ffccadcc05c73dcb0727b7217505ae8105adfe7a15e83173c2b862d4
+size 231084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..54b602506004ee2d356deaf462d11b32168285b4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a6d5160c4e8ba8f9ebbda46265f9c84a35adccd126b870b11fa0256024566a
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..502da3ca98c1d18ed7aba431147375dc52416809
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a77101fde75b0212127b72663a02a2321d8c29dfd2adfc92a024f04746f8e624
+size 1131490
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..4cad456384caa00fcd3352004545f8265e3c5b80
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_3_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4823603ca250881735890a0224b3b7bf81107d54c662b36b51abc8656f384b41
+size 528284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..f9ae49c4b3d79adb563cdc270716716d94f1760e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2abe2350293ad05fe8ef7f5d3ffdf8b5d62f99e264257e6f982ea5d2a960de3a
+size 96582
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4.stl
new file mode 100644
index 0000000000000000000000000000000000000000..10280d27b5be12ef60bbfda93a42c65133a63945
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbf6666c08a1fea57463075f261f2ab3eebf45f764e1ad104ff076ec4f1f51c7
+size 423984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..3659b8b3dfa0f5bfc2e4f53fbd17c221af3fda68
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7fb6f71fd66d56883459f6d489d3f14c05ef6093a538146d7e9c35c5b5cd8ac4
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3f35bd16b815bd0f855b61d6add7e9f25a767119
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8e1f9bbaa11b24d73a1435e1d7c4cc926c25fe879925117216a31ec1845e8cb
+size 2999816
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ef8aa38e0de4cedf965368bfd67ac9c69fce8593
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_4_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89dacf4668f45ee51cd1f0a7cf027f180272c170ff1a1a735695a58dc2ed864b
+size 1327884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4a0b13e8614c6b47b72094b5f5ea1aa83b8308fa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67dbd40d05edf4cc10ab38c5a9f84b83ed3902ea0f14b58a30ab01190ed2f1d6
+size 78828
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5.stl
new file mode 100644
index 0000000000000000000000000000000000000000..eaabe56277598cab20b746c6ef868bae5e69ebf6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:485d39c71a5528a4d504a8e81a94ef299c1fa5f51f6807abb0b4ea9dcfb68869
+size 220784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..2404d80f7dc78a3e72bc71c016c91406a152359f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98daa94aed669f2db37529a42819c873c07a68d0849e50486e642f0a9cd64d08
+size 235
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1ab6de93aab71d4fe857b3921cadc904e3e324ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a15e1fc99c265c208b05b1fb5df0b0dd1e87c2d1fac5f36671d0208ea7cc50f6
+size 1254618
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..d16f44c14dd704bf147da9575c77d34f26d53af0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_5_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7727674cec42ad82e6f6a73cb5055e83ee445a213ed12cee1e624aacbc707b7
+size 580684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5ba6413ba362388b2a34077299178239308a2a62
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad373b171e80ce3bb0b262d7baf57929357720896eb290ceadd7b4515391317d
+size 70813
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6.stl
new file mode 100644
index 0000000000000000000000000000000000000000..a0c931234bbab1cd40628567572cff3c855126e7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df125a61e2862e6716896d88d2b929c2daa91b26233df9ecb98d4b2a2c90b3e1
+size 484584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..cfcb76337509653df1d61a61db9a9e483ccaf28d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f1fb3e3100f1fbc6519277625d567e67d1eb18a491f5ab88f1a60abf03424f5
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e22e849001650d17b6000cf408b9e05bc9665f3e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35b680097ddaf1d3db9920eb5fd00c6f76f01d2d632d9c5df32f1787359f1243
+size 3007254
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..01c21a59bcdab1baa06ca4bb83197b6f808030d2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_6_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72f756210c1bb50bcaad00f443f3e6a8d68f946b5fbc69751ee70362ab1e0966
+size 1313384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ec9ba43f88b638d49b293bb97657ea59dc911d4f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4daa6315750ae17470e3b7c24a11edecd3676fd771dee9b0e3d4c44860360d28
+size 96043
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7.stl
new file mode 100644
index 0000000000000000000000000000000000000000..82828f6ee1a58b200f1b4757f2e59b823f1fe7c0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09bcb05bc0f47074fd64c3d9a0546fe013ace0939b0dcb7b8588b6af06914637
+size 653684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..3ce79682ac3b74963638bb7db25964697653684d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:646f866567c527bb730e65e31fee915beeeaa36631227f9aba6038f474ba9ae5
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..379bef2de9ef4760effce35ff3aba4adb4940ff9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5e5ab7b6a099b0c63e4006bf1836919ee24fc703925d5699d04cd2b5c36cb10
+size 4662165
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..04c9486ff1b7be2b2edc80953b4b0c50794259e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/link_7_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b6477fbbc24007398fa2972d6a3b15c3cfc6a6286dfef95732deef0bf8569e7
+size 2130884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..3d1b182cfdd00212645903b3980e1597d0cbd76e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9a74be4ae114c9acafccf68e1e49d8fd815ec030012ce1b60bdcf9b30db49f5
+size 2734652
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..d14dd8ea3ce64fc27bbf003316b5c60fed2de6fe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:006ba359c617d76b9765b9beae494301c2816bae4726504b4aa5401a618df844
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5a74893fa216810483bafbba10b283ccdea9737e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/meshes/pedestal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5665d1cbbf54e71bec781ee0af99684d2e516619817df250e43782fb377c723
+size 4143393
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..061dc5bb12f97b5f85d00086d49c8912a7d09b52
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/iiwa/robot.xml
@@ -0,0 +1,88 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..6b7de5ee329a0bb5937cfdb270b21b940fdc2dee
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7eb9e711e672934d2808c47b88c14da1e0976514691aa667fff6b1e034505ffa
+size 469658
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..7735c3c0ea73461539b6124ff8db833c37bf2d82
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfc9456293025112e1ac54bfcaf7002336ba67f9621aa5210dd83388b3c6ed17
+size 225
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7c62633de88da80dd71e601dd3685c0dd434c3d8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ec4a0297285adc071a7989793ce38230b48378c14358f7a60178c2bd23b8a18
+size 703206
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..627b39b9b59daa6a54ad95efe2502481a3cd8701
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3845b39cf479174a0405b4ae1830ca1fa445e5451ed9a26b7870f3d7773dc19
+size 321084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0510b03c108882d5a537f7ead853e14acc698962
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ce56479380f41c816670cfbed4fad600329ad890e1d2b5d7beac891b0212f24
+size 467029
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c0d7270263ba65dcfb5c13be911e075cfc4b4458
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d11b5b4f8e111a2a2d6d0351c89b77a39f34c2fe92afb1616ec279075b3ca42
+size 229
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1878f3be721c58c833422b2d7c8ad1ae499487e5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a01b581909e2e1a803d110bfab7c89910315389adb01402ba2a9a597e139d917
+size 704116
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..902eab9217526b673ce6d7305099cabccfd9aa98
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/arm_half_2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ec7a8deb28027d4c785611449ba83da2b61245c408646f7dbe15b297563143
+size 321084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.dae
new file mode 100644
index 0000000000000000000000000000000000000000..c4aca9a60ad20d883a652b7ebb2211e9e65797d8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84d6c46a5cf934f8d9abd669928948857ef8eec0d537ea226e9e44cd72e5133c
+size 1336047
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..1ff704558ac720191e6fc588fb66b57030eb9515
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:681c302b4d46ae0a54aca2befb0f70c0fd31d98c562bf67205a9a26eb361cca2
+size 233
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.obj
new file mode 100644
index 0000000000000000000000000000000000000000..51b9118c5054ca01de5cb9fd971de1e6fe9fd190
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:221f1e9fd309b3a97e781da619f55ecf3a0a1f07452b99745729f2e29a4fb3d4
+size 1886377
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..656b3894aa4edd0d936ab364149f1edd7d2a31f2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5889fe8e3c490dc41b8532c92deb7895dbf62ecd1544bd43942681efe0336b60
+size 853484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.dae
new file mode 100644
index 0000000000000000000000000000000000000000..5af3dea8b5fd5d41de6a5fa41d50b87a998dbb01
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4da6f9148f2b3fd0d4fdfa13129f1a1f2080277a07598540636c4b9eccabc170
+size 547752
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..3a2d845a1579b03b1807032e68e512b461a05d62
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97e4d6a811c2cb53a32b4fbc04ce0f4c6170533c99995d42305336a08bebc8e3
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a5050a9e3f28e0e90e5c3f00e975bc72b3042e36
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25bf077ff919ca1881d82465e5b2cc77cad96f248dc638911158eeb1382cd6dc
+size 761635
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ca64dc7649e275a77cbc5d38e9875f0695316172
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/forearm.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48cdf606a4d84b329bdb37508cc2aac522312ad61c216c3dbc8106f7df022256
+size 350384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..3d1b182cfdd00212645903b3980e1597d0cbd76e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9a74be4ae114c9acafccf68e1e49d8fd815ec030012ce1b60bdcf9b30db49f5
+size 2734652
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..cd396d468b9404d42e14b2272da5b7ddd73194b5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdbc8ef1d184ea57a0bbb1b6e688b3058da58ac7c1d6981a697e419cb473916f
+size 232
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7375daeb23defc94c51297f77995e5b9dd661122
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/pedestal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61cbf39abcbc7aa3f20642e370974ee05348d070521fadae11046e0ee39bde9b
+size 4140465
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.dae
new file mode 100644
index 0000000000000000000000000000000000000000..fb3f8fc7ef0614b81942e6cdf89cbd34e629eac6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddd3db4dd2e64ef0673b90a0b2aead433c65ee81fa485a64c169356dcea2b952
+size 34042
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..a7f5b2b2f66cded0149cca162555a5ea086c706a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a40128e8b368c2481b6bc2eaf57fbe687bc540af04dd37e6b64246e12045ee9
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ce0a217c8803b2d9054fbdc140aace38643c6e17
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c72a733aacb06838a754281355c290b6db22baf875b9c7022b9abf61e40c6f42
+size 37985
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.stl
new file mode 100644
index 0000000000000000000000000000000000000000..96b9f64fb6e77f248a87334c38894542a2802635
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_big.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0715633b6d217880a8291843f4716de29faa317c7365be7d3c2c6b95ad47c58
+size 22684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.dae
new file mode 100644
index 0000000000000000000000000000000000000000..84a5951ee6c71476b32936883981c593580974d0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf4d3b83ed06116b66c6d33a7e2d5e2d2643b7a822147c44bfbba80b5b27a234
+size 35177
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b16b231843c0fcb88dcf8f323084acfcc8fdcff9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:209bf888fdf75b7b8d88b4be11800f8f6f3e559c2ed02b6bb44e26348516fab1
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c60b2d412cdfb0f15a2ea0fa9bea6a2bc2dd390e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b5e76280ea5e11b18c34137ceea1af0a32c5a21a511cc9d94c1536c3fa6d85d
+size 38601
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6ae5c8c3335e4d54343b30fe8fc9959d20fb1419
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/ring_small.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d60973e9d9ff8c6b05d49a120bab2fc8df42e7270d72e369808e9c1678db5eb
+size 22684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.dae
new file mode 100644
index 0000000000000000000000000000000000000000..ea76ca19b9792473706a84ea927588c15b51afae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c94f20e293244a5cf55b724cc19573fb5f4dd3e7fa3ba6fb17f1582bc9bf9963
+size 643584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..436a8a53d3b24617aee67469d19c104d6da847a5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f8dc3a3c8635ab855c42c05f8b3079f9ce64a5074962584645bd04d36a62edf
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9b91ba4816967e9f7b4584a95fbbe7286c266489
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41c5e331f1a1b09f39f488b8762cadd81e1289c8c4c990d1f899d90ecdaa75b5
+size 903959
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ff2984623ee23cbfb35a72fa861752a3592222bb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/shoulder.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:142067f7985c4fe9cdd94738a0fa8a7fa2d0ed713851d275c4491048b6157e99
+size 410184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..ed1fdc62723dd30490650e0e52ad572364d9a047
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03bd008c3511cef0a42071b223f1089845f838df17975450919d852597032b7f
+size 566902
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..d5f1853eb4efba408083e3fcac3302307ceb28ef
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34a3959c4f5e5dab0d157e8336ccbd2f062c388fd75f7ed84c48a8f844b62d25
+size 229
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..03256674b09cbcdb4af71b97896a256da967150e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfa6353b552b3c02832229dbf520fbb23e16159878e170eba19c3d71fe94625a
+size 861304
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..62748fa31e8cd2a2d2f206bc1fdc9d559833af87
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e703075888eeded30c99b3daa83a98513ef3dfb6d57456b89f2f2240781bd62d
+size 387384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..c6515325880d144289e288b07640c4f60b5c65ef
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7d7924da2ca1eb6ba82c1415aa3171c52c807baf900222bf64597cf91516eeb
+size 568819
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..4614133957b197a54f8b88dc0b9780bae3916b3b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ad264aa4161423d1c26cc9ddac03e2a1270783d94b69269dbeba91e80789908
+size 229
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..639662ed467aa9f03bf1d187601d2502ac31915e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20aab8b37fd65a6ecf05138dbce7941c0c7d07bb1c553dd830937d6bdce10cbe
+size 861905
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..eddd6fe13ea74d96541e18252780a34ae3d30007
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/meshes/wrist_spherical_2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7eafa8be6d2d91ca4592c1bb642c116d879b5c0ba952a0ec3dc9d111a86b5dbd
+size 387384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..8c82945a323a7d79eaceb67955d2d42ef3d99a53
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/jaco/robot.xml
@@ -0,0 +1,94 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..833c8d76591d348864282da750f553e3e1f6feb7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98838761774aba57f7d313c1d4c0d5625e6a89a178adc9799449713bbe7b0c6d
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e1f1647d12f3b54b48ce834df5d35954abcc1e49
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c825a124049f93948281b649c7835efe29759ec5bfaedf215f4ecec48adec2bf
+size 354327
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5f6aff0d1c58b052b4464d41029bfdcdaa91dae0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/base_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e38517e653d58f4160cfdcee34c9dea983d8b3ecb8f2c147444f58235335260
+size 183884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..3ed77d1a7fe59181d2f777149a0192cfc3deaca4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4696fb326cdec70ff50b429dcefbfc26ba26783adf45cd0d09575ec8e8530e9a
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ca7bae252e93e44209b5711654e8335b7b7f77c2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38cc80420d197714119a515cfc95173fedab5c16186a444b564ac3efe94994c7
+size 2903680
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..705d64f18e58ff342559845393b277df90ae3c00
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_no_vision_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b2ec7dafbf6a3c9a2cac717d18132a8a1efd9fcb8f2c40d9c12c8ece7323ec8
+size 1334384
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..0950fd51318b92e80add78a82e8964da275f4da7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8eb1d69d678cddeac34e151a4f67f8e2a2f9db77668d31567d65beefd0df6c3
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2954c910a89904ba02e0f2aea04297b918aeff1c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df309dcea390ecf8e354cfb50b8b2f90efb481979375f243397fc88027e4e101
+size 2561538
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..72f81c50b040566e324f098cf263c7f9b84def47
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/bracelet_with_vision_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bb05e420a88532b4eee3948a0047fb8e7f341c3e5cc971d54e97e2aaffcf878
+size 1164084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..cae2df9738c15c518c28b53c66fe721b3a76b62a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6bef05ffe22bf1918c479eecec9be20ed7f98455e6ebcebdcc14e5fc88ce8546
+size 47
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3ee3febe0bee3d37a24f033d69fc9ea2149a0ca6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:252e3687fe858a3a574ceb798f9e49727ed497820faf208e72e834faf7ef900f
+size 78
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3e2754219dbba13475bf5179dc55c8aeb1ec4bf4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/end_effector_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b6fb58e61fa475939767d68a446f97f1bff02c0e5935a3ea8bb51e6515783d8
+size 80
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..9cec2eee6960ebfb966ac193a6287b10e763511f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beb7e4b5cd33e7c65396133c8a2d2d17a42a028f07a3e904565906e7ef056451
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9c3eecc1639d613e04a1ba535715788d4bd8f754
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a34e75d356d679b1c033d8b18192960faf1e4fd804126b17b30c36e29e7994aa
+size 948560
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..0c616b5f7115ce35cf45991c428bf149a9a7593a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/forearm_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a00d96d2c7ec37e9fe4c5fb6d3fa2b4386dfd48fea5ac25632f4cb83b3505105
+size 456284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..1db586566e18400bfaeceeb18b056f016403b9ab
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19d1ee125ca4ba44781bed352ab6aa2ba9687ac4d57d37d13973e574d42f2463
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..183e0d74cf881f37edc3767260883270d7a321d7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c186dd0f5afee25e06230cb00c30f70fcfbf869f43a21be3e1c4bb0e47c4a21
+size 1051661
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..00b636d56211e93fced7042e4373bd32436e4353
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_1_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8df461a71780d8e2d09104a507584424279d85173adf14867e359f7e58b3bba7
+size 514184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f9c8fbae18a8189d1eb9623e790272906bafbeba
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc5dfc5f663bfbd0dc05352d612f06a8a2c6d0e4cd65d0b460bfce3b2ef83346
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..29de7606f5cf69cc72d2caaba96c02b74c692a51
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d67b050144df7a50a8ff5cdc3c26a557e2abfcce6845daf1e64334dcd316f95
+size 1005957
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..375f68608fd5cedb74e2a4ad0d1a0c3fb6ba1dbf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/half_arm_2_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a9f38de21648804b3ab7076ec73726031f2ade7c95eb2541d070553e025407b
+size 489184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..3d1b182cfdd00212645903b3980e1597d0cbd76e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9a74be4ae114c9acafccf68e1e49d8fd815ec030012ce1b60bdcf9b30db49f5
+size 2734652
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..43b29f033881c222955aaf42bc6820c23bdf92a6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:530777aefa4b4db928631ca9c2db96e5876b0685fb821b1bc8130dfd786dd0d6
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e4974e2e9d26fb45580a525207c1c9fec2082e5d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/pedestal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7b0bbe6fb0d171dc28df60700f82896cc83a01644a766ceeec0e4ad74e6f544
+size 4140469
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..5e1ed20951d24ab46a84983cd64a8324b7411c68
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f8d77769f51177f3982efffa9949adf932524d128733cc2c41605ca28ba171b
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9cf1dee15a4d86166155580f1cf4c8ba19c6d14f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72c3edfcc1ad141dbf299509e5d8245de760d213de8f2440281da7e232a502cf
+size 1017935
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..71dbb2133e60e6fdc5aa136217d5ab9b163b44d5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/shoulder_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a98e232e5197f09e9cd520e7a450c21bcec74bc49e2a19ce7cf86c5dc25f2e18
+size 476984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f74f823fa29fd1ae18c44cc314b56f1328e0b925
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cd709ac78fc6db5917af4fc6a4ab7c6041b755bc07d89bb5e752294a09c78fb
+size 232
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..21ba0ca9a01455e9326216e2088466a6d622f2bc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2ba21dae81886c0c9915534da2887d58079cdf8b6b6b669ef080ffe52bce7c6
+size 1155068
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..cc0413db97a49f0a4554ce9ba194a5251a65f4be
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_1_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b97a5e3b2de535344f8e8d4176d77f0cdb3455b235f35430fa0413208856e66
+size 546484
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..9cec2eee6960ebfb966ac193a6287b10e763511f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beb7e4b5cd33e7c65396133c8a2d2d17a42a028f07a3e904565906e7ef056451
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6aa155dedb8ca4ddd2567b045fe4d085ccd5778b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7302bdaa889e4e485d3f1b6b2b46431dcc86daec256c31dbe58eb75770451fec
+size 1084389
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.stl
new file mode 100644
index 0000000000000000000000000000000000000000..aa941bfe51911b1d342252be944c3b57e88e2274
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/meshes/spherical_wrist_2_link.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12271fb7d0ee78fb5642424cf5078be20f0476061e74203db964b674031ce879
+size 516984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9b1a239369a61d45de01701c643822f3a447410f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/kinova3/robot.xml
@@ -0,0 +1,80 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/finger.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/finger.stl
new file mode 100644
index 0000000000000000000000000000000000000000..ef5e672efbb990561b36fcee2c15b2f61cf42065
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/finger.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d07a740392f3b9b0816f65d64fff9927d3d57c897870fc4b6ff9c56fff3a0c8
+size 1684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/finger_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/finger_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6c81071ad1b6f27bd97ad72839d9833f23ef439c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/finger_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f5556f3ed25c5a9292e342d98c09885ddbe39c7096fd6cadf59b0cd93079fd7
+size 31284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/hand.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/hand.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bb315217a60e27343b84a9d4e3a4686762c4fc8d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/hand.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94493e94f30fe940f2c8ca2f155c3bbe67bbff406d3edf5e261670d2f0f6e2ed
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/hand_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/hand_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..aeee94b7236eb7b6702fd1c65beeb463824ee9d5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/hand_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87a148e5a35a67f3a3e04104d1d63b7056f91b416e8aa0f37d5ecfe61e923fee
+size 353984
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2ea2fbee592e033a1cdf431400ad8c2ac4091248
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9ceaa66bb3a734e3a32f2f737ae57a29e922f4a962ed77b9bb8d8e25cd33159
+size 1590896
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..bbe58384ff30b933eb8758429c4f5cbd970c1b50
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfc6d94330de8ddb005b311bfdba9f3b8e1aa7c256b71592ee7ff32cb9a9a5aa
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2ea2fbee592e033a1cdf431400ad8c2ac4091248
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9ceaa66bb3a734e3a32f2f737ae57a29e922f4a962ed77b9bb8d8e25cd33159
+size 1590896
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..d89dc9e808c4b05d05a13de7cce032beeae8ccd7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12b22d2d2823d4c2f36095a505c3bd365200d4ae83d2b7d9021715439c175b68
+size 2316
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..306edb5b025453a5dd6a6df94a71e9d3632cccc3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:accee28b0c08b14177854785a17ec6c8b7d282b2319b80f6fb2e5a135c032c69
+size 2293426
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b090d6f954c9127ee749f67f8781c899c595a13e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link0_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7047d9cc40f21d4e23ecba81d731521434e0cb30b278c318d5f12aba48105081
+size 1024234
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2030d60bcfb6e3f00c6287fbe8cad91be27118f0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c8b7b7c1217d620a811fc0ee52d1d1b0e1470de955e7453872aac3f15cf7c5e
+size 978415
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..16bc4cdd84f3dc82098e6e15f6e3c7dbcab73786
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d24e332dffccf260b91d05dde17c5998bb9559d37da8608a8ee5213d9661f603
+size 625884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2030d60bcfb6e3f00c6287fbe8cad91be27118f0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c8b7b7c1217d620a811fc0ee52d1d1b0e1470de955e7453872aac3f15cf7c5e
+size 978415
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..9e84da922e332d3cf4a86768085b08c715a17b37
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b3792c21291581dc04e309ce2494b5bcf8fffc8b13bb86dbbcc7e19cd3ad9e3
+size 238
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a63f971a288c3e75e7b9b27bf2e7bddcd871cc94
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ecffb03db29cbe2f7ffadd5947409c7409f4f948e02f919fd318d925021b028e
+size 1374032
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..16bc4cdd84f3dc82098e6e15f6e3c7dbcab73786
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link1_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d24e332dffccf260b91d05dde17c5998bb9559d37da8608a8ee5213d9661f603
+size 625884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..64981bd82b803b79f46005dd56885b4ae01e9d87
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c44d0364f0030007e427106a4e842d835ca43902716cc46ee4f3342dab189e12
+size 998486
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..6ba548f4137d4ba09e7b0d9299fa631b27af1ea1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:370f7605a0fae3529db169ded50f52f171024aa792d4d773bc84197301f6a039
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..64981bd82b803b79f46005dd56885b4ae01e9d87
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c44d0364f0030007e427106a4e842d835ca43902716cc46ee4f3342dab189e12
+size 998486
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f788da796a3ca791baa1bc1e32972c094d43f9cc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8970ae2bf72033dbcd5f9bec4f104fc5447188fa2b929623f4a281aef3aa8b38
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2001ac745c81b3788c996a3350b11d7490c04ba3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ada4575212260e77b96c9e5f3b42da0d8e5e6353eee57f8dcb7c6ee4ff559ab
+size 1386917
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..b15afeb957390112c44330a980bbbd414016259c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link2_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ca22e40bf92b26d88ae63a180867b2b6d226dac9204d8b749f1bd8337fdc852
+size 635884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3.dae
new file mode 100644
index 0000000000000000000000000000000000000000..23d6124df5e3d5696241441fbdf7dcfe53dbe150
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dab39a126153fb82f3650cca6de63a8e978f851aa5020a8e91b3d9d548dbba3d
+size 1099651
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..7115ba0e92d33fd3a2e6e2087df980ae8b9a6730
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a8d638b9349c6c0eefc4e888636ac4838c4b27170f18a51699321118af709c1
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..23d6124df5e3d5696241441fbdf7dcfe53dbe150
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dab39a126153fb82f3650cca6de63a8e978f851aa5020a8e91b3d9d548dbba3d
+size 1099651
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..d8bdf6226805a6f7f9c61a5eb0ba744fdf768efc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f019f3a20fddb775d6a6ea969d28f1a8c270bf0adcf448b9f7b6f4f9328881b
+size 856
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..93b8f9429192295ea55c8893653ae9e7cb8ccc29
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11f1504eba4f472585a8a3200d9aee420efc4196b31cb4284793a823ec007cd6
+size 1563997
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..cd555180b5df809da45f8a03dcd5b18aa248ab2a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link3_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9765165f24d0551b885bae0ce702448a711da22399c033c5e40bc56ad845b5f
+size 711734
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0ce1680db10d42992cb781fa23e3b5db43dea3ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e03d680e3a4a4555d673bcb8cb466e479f5cb069a5fc8a0b0f99c089c50fd63
+size 1145491
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4.stl
new file mode 100644
index 0000000000000000000000000000000000000000..88c6db70bf3c3b68bce08b9bcb5142050b1f9079
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0180ebb5772ec9840cb049750cffb29a9ddc90311752a16ea34757782ef9e48d
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0ce1680db10d42992cb781fa23e3b5db43dea3ff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e03d680e3a4a4555d673bcb8cb466e479f5cb069a5fc8a0b0f99c089c50fd63
+size 1145491
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..178626f13382ab50c1a1d557a9cf5cb006987aac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f08734c55cc3fcbca529277e5b9ea8f41e544fef6bafa253a3f181b6a7d3589
+size 852
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..db79df246fa4d6b36b9ea5f80d49dc8120ea46fe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32989e80f160b70838d65a51dd3fd12fab36a4e0317e469ff547f03ecc81d0bf
+size 1615731
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..9a1c7ab2f1d9be60429570d322413590b2e539e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link4_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75f852b23a9cd24735c59feacb12b4cb9ac018c2a1505a84b655144faefbdb30
+size 731134
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b6911ff709357cc25b27355fc36873d1d30f9cc1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0be76681192578a14d6ace89527e4ee418f7395e825e285125e05fd998d24e3e
+size 1438169
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5.stl
new file mode 100644
index 0000000000000000000000000000000000000000..5eaf5c8ec2155135ab9297d51e3dae6e5e280675
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd17e688c7870e722283525879643d53a74c0024d328b0e14b034b54c8b6c31a
+size 15084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b6911ff709357cc25b27355fc36873d1d30f9cc1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0be76681192578a14d6ace89527e4ee418f7395e825e285125e05fd998d24e3e
+size 1438169
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..692dfb4a552dd3b5d3ca5a2c1db99645075c1733
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58562759a863c7affa15368642a3bfa144e33ff2159ddd3a35f326b01bf8c105
+size 631
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a34bb50c9ebe5326196abbdccbc8b2cd5d22fe13
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9dc118edcfa779488d03c91f17e273360c96faf97ff82d145717d1da2d26f436
+size 2028820
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..0cc2f41f713f5a966d8af5030ceb3eec8d593f3c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link5_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:179de04c09da7ec80232f17265ed6b49d28aa4491a13aab679da107a74fff50e
+size 916434
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6.dae
new file mode 100644
index 0000000000000000000000000000000000000000..adac012b16351aecef432a28bd593edc0872a9ae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed9c57432b079d55b9954775f2ddfe34e8b904f683949b8eb6314238f8afa46e
+size 1727767
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6.stl
new file mode 100644
index 0000000000000000000000000000000000000000..828ad3bd384b22ef734d8add0e50d6ae449dce9c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20b768e99a0e0440b5754dcca108016434e57937cc356acd9c352ccd3cb27f77
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..adac012b16351aecef432a28bd593edc0872a9ae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed9c57432b079d55b9954775f2ddfe34e8b904f683949b8eb6314238f8afa46e
+size 1727767
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..5db8e4f693373969e375f7f5a54acb012c177a37
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d2f8656461177a741053ca584339195496d86c6e779a87ae8e68031fe0d968a
+size 3404
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9dd2a18e24b174200b79ee5e5c8111ee10818ab8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:907ec674c41933e28b21075e65ded297143d2691867eac15fd43721de670f18f
+size 2448144
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..83924e46ef92dcdcd777e0bafdb2f70cad66b224
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link6_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a25b11b650d4a1c96b2299a020e38e3caa592f6e3f0f483bee64823495b1688
+size 1081084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b6d289bc5b7d51793fd2bf805695356eed8ae3ac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71be614f734bd27b2d7dec3e8bb022251cbbfce38b0a12dbfc1b88bc0513822a
+size 935952
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2047756ec662f051af90fe61266998bf16e655fe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92ac6afcf7574c034d3170d8a68e95ac9048ab9d0dd5bbd8311b86e551b9ab1c
+size 10084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b6d289bc5b7d51793fd2bf805695356eed8ae3ac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71be614f734bd27b2d7dec3e8bb022251cbbfce38b0a12dbfc1b88bc0513822a
+size 935952
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..0d634f2afed484fa68d6ba05c06d087af0b6cc4d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce742e1ce34ba609c457614439bd79c33fae7cd87fa9e4efd32aca646ce0d3e1
+size 1644
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7f2ec2019fa98d7e880537172e4467d0a620532c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b80b32eec4c6272fb58114b65052f60c89e8d0e56307697f0b055cd9e3826d4c
+size 1272083
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..185bb136d06185b914b8f16765f216b86402b3e6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/meshes/link7_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d7ef70426828402935ff2bfac9f8be2f770a0c37dcb43e1e3d5e8c0c0afc5ac
+size 604184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/no_texture_robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/no_texture_robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9b21fad5b33204620d74c4b82e1bc95dfe5301e8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/no_texture_robot.xml
@@ -0,0 +1,92 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/finger.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/finger.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..685184642fd6431fa28c373ad174b027da2f0825
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/finger.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:525d16f9a1338ebc2f3c062b30e462a279356aa4fe242617d39e05fa6926c27b
+size 481
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/finger.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/finger.obj
new file mode 100644
index 0000000000000000000000000000000000000000..755b123571ff7acc83822ee38fe559ffd9de7bff
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/finger.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6959401c281e6d5b4486f08072140b6fdb3461678face301ac448cec06322d3
+size 65359
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/hand.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/hand.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..5ddf352addf4066bbae00eb184c7bd52c3d372d2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/hand.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da2fb08fb2efec231cfd38b710e657b3ef04d3d3cf18082d3d043080dd6c79df
+size 1132
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/hand.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/hand.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9c8e9c103f64c4fff4c6286309a537a98d88dbe3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/hand.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c95c6f99024842a8b9e73629dacc30ac5ee459feb0dbcc16df765fd27620584
+size 713204
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..12793d769dda4f4368d84ecf5383c31206fa0aed
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff032aa0aa3514c0820f8c07a676892fe5a7a85977ed3e4818d8ee8080cd8a2c
+size 358725
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fa7b471c7e99d2e635c0c01c9b89084c1d30df27
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d033464e59f40a3096d2d6cf7853afb5bdb258bf19accefa24fa040ea73763e2
+size 122052
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_10.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_10.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7c9f65b1293683e1909740b979c2d9d03b6c3a1e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_10.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d879fddef79581242c8d28bc4dea19759712f5d83820b9a7c8bf519035d52d3
+size 407944
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_11.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_11.obj
new file mode 100644
index 0000000000000000000000000000000000000000..f169e1f08a4c8989959ef2291bef904934a4a734
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_11.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f47131b0f4a638b64065fa9d7fa0f6d196c7aa56ba717ce9d127d36a67338a5
+size 24799
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..95c352c141b99a86897c14078f9eb0158f072d63
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f8782b8a222da3981d0c0108930fb8dadf44639a16bd85bc89cd7401ee5a766
+size 731627
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c464bf0f548080012620c8246dddec7a8898abf1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fc91c0cd10c4aef9a493699fa83457c225d0854b28d55ccb0b47fe38f412f28
+size 58037
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a141747c20557310d6dda40c009e27a7d9a08b1d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f23819a788f7ce26965e540cef7a9bf5970ed99da3c1b875a91b0e544511e39d
+size 256094
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..057b9c1e6d235eb8a129e8e967d2d4e5fcf066ed
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8e5ff98258aaca852979fd62a39d929c1d726d9e6fa19f1acc030553cdf2ed3
+size 21416
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3f8d4e6fbb74f023c80dc4ee5673544b996ac45e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41be78cbdc6b0d401efe4fa630fe6924ee3312e294403aa35a4e18a70f99139e
+size 4495
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_7.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_7.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8c15cd15e078f66486991ad420f87571ffa91a28
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_7.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75e3a2661f6366bb6c665282cd77f18b0da09ec1eaa8df29617fccf279970f6f
+size 34000
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_8.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_8.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b7d6621be9facc240620be13576685a5a43945aa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_8.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d751e2ad68ca735b0c3b0f8f11f7364026fe64166ebf4aab69446d4465c6129d
+size 3621712
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_9.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_9.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2cfb2ffd4095aecf65c47c09ccf38045504e9be7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link0_vis/link0_vis_9.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19817601101caeea79826ffe057ad0481b97f60e80870ee768fe32b6855ab9d2
+size 125294
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..23112caa06b319e789a26399eedb048a78b260f8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c6e1b7fe3a382b872c986cfb3601e79b701afb7b4bfc72dcd6abd4b31f9d0f5
+size 262
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7e840697010a583ad63477e55fe911f23762ddfb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0ec1bec9432a46f797e912cd8165798e62ebfd85c0807dce142bd6b9a768cd7
+size 1070650
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1_vis/link1_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1_vis/link1_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..207c80dd76bf5676fc10da91ac22c9471e241b5a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link1_vis/link1_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e6cffc2656cac8c8faa9b43e41cba312b35519c9cc4c5bddd7159f14df50f2d
+size 3676668
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..2f3dc43dd42d4fe485f1863772b4292ad67a14be
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7821edf4876534ddeb9cdc789453e6473e2f1b9d33747f036005788b83199da
+size 261
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..72fc313fa59501f85146cbeeff5d94ff56d8356b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1a41ada6d6cb8c51a8d0d19db9cadea101493842a52db75529e7ba0d4592904
+size 1084458
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2_vis/link2_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2_vis/link2_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5de65d643d34c77892f6c452d5bd9e755ef28c62
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link2_vis/link2_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42ada2805819bd9364e567e23b4c1b71b91d9ac02525598d823c14bf8d25dabd
+size 3713545
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..001972f5535b391822f2ba5ed4070cd04414e3cc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2648eae04bf57f48100234ada9400156b1867d0fc2dbda2606bb786d6cfec51
+size 931
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d0b342f4ef35a302f76953b3c2874dbaa37d90b5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82f81c15c236a41918dade0241ce40e59c4a0ee7fb08985f3a934380a8f7f594
+size 1237175
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3d6091013a27ebaf994505530d2269f756de0544
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a1f18e21663f2acea23752dcb79b2a8e8ae96ab286dd9d90de1d61c0dd87c0b
+size 3499294
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b7a35d54a5bff2036a6eb54e74c032895ab2c16a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73ad9e69c96283cc865424a64e38b53b7cd4385a0ff9834cae91a258ab92cb97
+size 81340
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..aaed2765c1c528454b22eedce1dc8a773b70f81c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9cd5a95e3599005bfc76e3700f1e3b3051efe07665b0ae6efbce75860a89e7f9
+size 98021
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..14cc7bfac589ce77334e02d80cf6e8c3ef045ce1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link3_vis/link3_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01f69ca50a96cceb1bc1a0fcb4de9a7fe2b5b11d1d2e2a5c445723acef7eb483
+size 571618
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..9f1a2e75d4e6635d47ec16a79bbc096e8bbc0355
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa5e228aca4026e34e69dc9eb80bed96cc7d81dae5595984c16b464de5ba5963
+size 925
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4539ff1edbaff4bf1fcf058aae80566b960307f6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91f36ecba40ecce6c169ede6d152a6c8853ee886e2b145ef6c5865cc9d5516b8
+size 1272305
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0f0ca00ede27baebc4b78d15224ebd1e7ec284d7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76f42fd26edf4a6b78b055d3b48c4df59cebb93ffc36fc6e6e6c8d441dbec3c9
+size 99644
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..16b3a4591995cb42ad2ab2dc43608a590ec0256d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85907c1c058ea1133cd6cf5a23535a1b77257b4852c753e08cbf87f79c4003c2
+size 3608854
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e54ceeb606068d3f57ddcdfed1986eb16b8c87e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a019f7f862b782af509f649050ef45354ff73642742dea7525d8e1369a07fd94
+size 566926
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d4ab20950c045a29aa22a7174aaa1bc35ea2fdcf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link4_vis/link4_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0fcc4417a0680860ee11af60df227ecc328b4b1c66908b9872dedfa8620d992
+size 80397
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..a68c251016268d4e78d167ca54448069a328e0c7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d44e731edacef0b66d7813468107877d05a12893f363ce3a0985c8305f404f3
+size 676
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ca493ec395907a490a279498629cff35f119ef63
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10d084e6f728af7d291403f0359373c121e87785b177e5c6df8a7e1e88077f11
+size 1582192
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e19d2b16fd86bbf114e1901aa5d8f9593b19ea88
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e30aa36ca950b2b57ba695423ee30da5da662b1a70b72d1533c5cd435b88355d
+size 907061
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..90a2e018b97d9bcd316cebd701c7308ea0fff767
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d4089d9f2d26ef677ff42b27a2cb4ba9691a8fcd19ac665ac5203b23d22b876
+size 45788
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d37b92cd0634bc31da3bc66b1992cd12b08621a2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link5_vis/link5_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c095257424ff95526d2ebab4f396833a2d77227e0d2b3bb899a0019d19c38c49
+size 4535777
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f8f25910225c6c5c60ffc027f30f68c19e975104
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78299752f3c28e87c19f5456e851423de647fbb065f8d88332c1f72ba5729b4c
+size 3552
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..37b6df71209edc724f2623bb8bc22a01288e69fb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d6c29421e8066761c993a519474eb58f851885387cc2ad5a606f96eb428f059
+size 2236857
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3641b11521f80eb46eb76c1e00448d9378b36658
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca3ca6f238a0ed1138bfd5d11982f610d96dfc98d2d663646f07fc4c2a3ea7b0
+size 181786
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..32e8999904fc52125ee50cfa72c591c3aec7cf38
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a96dc760e3b8b0f6ab4e098248e4019e67ec9fcf404c6ab944e216dfc752e590
+size 23299
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_10.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_10.obj
new file mode 100644
index 0000000000000000000000000000000000000000..87713887237e632206c961faffea0a829aa7b043
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_10.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:915c2eedef0d04869814126a96da1ff70d1fc89fc65ac80021084b00cad5d757
+size 462054
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_11.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_11.obj
new file mode 100644
index 0000000000000000000000000000000000000000..92397d1dcd2b8cdb7ea27573bfae4c54ff1a7f32
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_11.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6edb35baa656d0b647541ab9de4706c4d471ec831e732955f9f4d347f4bc02fb
+size 30999
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_12.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_12.obj
new file mode 100644
index 0000000000000000000000000000000000000000..69d264496143fa2800a40ca0c4204d961dcc07b3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_12.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5af3b6bdecc4de424eb0217627bf0f6185f126dbb0bc1a7fe39253f0a26f2907
+size 4214
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_13.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_13.obj
new file mode 100644
index 0000000000000000000000000000000000000000..00deb8bc958807e822e4583e01b894d4dd850685
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_13.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21dfaf6cb5f27a8593be82cee6ca3023e3829d5f2be5e30af54b5955ea87a554
+size 4327
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_14.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_14.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a4bc5fbd96fe9091274dec5676bfa4ecf5af6121
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_14.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8aa89ae8b85fb54400771754d25d6813dfb41299821bc984f37498ac0d9cba14
+size 533553
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_15.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_15.obj
new file mode 100644
index 0000000000000000000000000000000000000000..82d05761e240e015e1cb32b38f0ba123377391a5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_15.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9c3bcfb59c80b0e9f67c62775ea7564f1f11f55faf94b971c7fa2ee436e91bc
+size 819438
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_16.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_16.obj
new file mode 100644
index 0000000000000000000000000000000000000000..24e704998f534c62e3d848c1c08f8221424bc8c8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_16.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e76221e7c530dde22117afc57d95d48ddf30a5e586e3bd99ba6777a30b50e58
+size 4284801
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..90b0095a7ccdf63d4af8bd4b64936008a7b64605
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bef85a944a58e7032a765cbb10445ac44450fe047f3fefe501e3a3ac90dd6caf
+size 10246
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..28c3e63fbb4bd495dbd7744e15b5b379783370c0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8f1f77b3978a16fcc8b4dd00af50431d1bca26576ef92f73d90d984dd188211
+size 15976
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6942f140fca63d79181948ba71d913bd8174ab45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cf60228d09854089185c5bf9f23004603b3aef874b61b2b8dad7bdcc29178e1
+size 18377
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6fcb136c07d912e5e2095dbe7f85e5a9308b9c70
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfdcbfb610c052e2a4830fbc26a53ec99c0658ec3664a179c5471ff42d6a60de
+size 15690
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..71a321952d17098dca29d010868964159c9f5256
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b13bafd7d8cee84a07e5a0667da8d1e1d9467112415332f23ef79c582421f8ce
+size 16884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_7.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_7.obj
new file mode 100644
index 0000000000000000000000000000000000000000..99919ddd372f505a4965610a0fb7066c3a1e252a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_7.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52d8f9cac4326716f46fad9aa1314491d377b1b36a374258a97aae196b006e51
+size 2464
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_8.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_8.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6aa356f859758bb5dc518c2a8305edf45ad7273d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_8.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e586f0aaa641cdae0d372bd2fb689600cf13fc629a52ba4ad6076bfa7b0834f
+size 6323
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_9.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_9.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7826b2dc9275567373bb710538e37ac15ff2f0cf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link6_vis/link6_vis_9.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a3caf4c3596ff2661d559ee291a9ac7f66a6d943ebfcc417f229de72962d0dc
+size 16367
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..44489a11f5f6f21d19be60998117bbef3457a5ca
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a43571b6feb5dbb0b6f5b3d0a06f51f356371d27a10fc580c985107db16f2ac
+size 1831721
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fcb08857f1b2765c1f5e97f62f98b0676e786f99
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fefd18ce20e90dac5ce7d181e676e6cf7504cca1f050c4d05c098169b417116
+size 162840
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..88bd05f1d05340a5b5f113019c68b33c54e0010f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44ae15f4a83bd36f4de3cce78e3f592c8a070ca1fd4539be3b1c58e21d8d8671
+size 280928
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..eaf9300e7d600d382b71d4e5e3e01d9b5286eedd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08c6308530928ae7234edd1297039aff02b0d84d8bbfec8d01c5ab104fc4209b
+size 164977
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..50e878ca36e0e8ade0dd1c4c866ca6f270c66c7b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53e0bffb56a0fe9ebab0f43dc997512287ab14d8f9b2d77f1baf7f0c96367f59
+size 114146
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8f60a8f9422247b4f987f1eaf3ac7dbbfb360d86
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fd161fb00e729732950770f20fb6e75c5a5ca1620bceac9d550c54f322d469d
+size 304295
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4c404437c9157bfe08cefdd9c7c41401a6cabb95
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd4f2b39adfd4857f922f3104fdd4cef01390265a5f2582c73d97f39dbc1755b
+size 133885
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_7.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_7.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c26dcc772f9ba9e4bf7c2f9b3313dcd36967a20e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/obj_meshes/link7_vis/link7_vis_7.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4f9e3c8532df4d1588e1c2ff82fae10facfb30a52329b7f595ee37d7bad6229
+size 1058189
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..dd30bb8d028642eef62018c05c70feb1b5db2350
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/panda/robot.xml
@@ -0,0 +1,252 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.dae
new file mode 100644
index 0000000000000000000000000000000000000000..79993c4d499be9416a86e6f5abdcb35845fb7ef4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f879eb086ce265c73fa16b6c39b9344be2545e293fabbc4cd37ae405f991101
+size 1000898
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b1bcc9bf5d30a204abd9be1ee0ae6126bd8b0a73
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e96917a254c2ff4ce42dd5ee0c82bb33d632013f9aeb0493b7205fec5d68c2d5
+size 411
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.obj
new file mode 100644
index 0000000000000000000000000000000000000000..256cb2dab250dd39cad70c40308832a1a3c24097
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aaeb52c143f716686b9ff094c41a2af2e761f8c6b0e633d3bb9500202f1a72c7
+size 1423343
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..83382c3684a5416298a7889732f86b57a5635efd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97b3997a278b7d7be42142f49f435a4a0d7856736b943bfb3590dc43210055f6
+size 264934
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.dae
new file mode 100644
index 0000000000000000000000000000000000000000..93bf692852e762e132bf998061910aee05f2cede
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d367f8c3b05f0a9261f1dca2e9040f4ae87908f343448c7ab8ddd51f74884ec
+size 1389340
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..3886714b14a4116ac77b23f3b91ff18403fa2ea1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49c9ecac411073f585ff5fea5ca6f29e90bbdfe4254e63b7bfdb7afb176d5593
+size 1828
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.obj
new file mode 100644
index 0000000000000000000000000000000000000000..95de473d78781dadf9bd9acf14d8bce2691ed872
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e98e325ac05841d7b3e2fca1cab5d7520b4ff8de51532cf3f615de4f4d03bd6
+size 1966591
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3fbad3d5b002cbdf5046ec679ea9a9f32aa780db
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/head.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2d5dd7e7988417897c1b3c2b366ad9dc53e0c096e0a4db810f5b6f08e1385a9
+size 276234
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.dae
new file mode 100644
index 0000000000000000000000000000000000000000..8f39b917258adab2e893e8ea9fef2abb8e0f1bb6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d5ef27f6f53e5ab78ee0791d28f47099753ab5966132f67339e3514346eefd8
+size 3760739
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..599be372537b75781e5f7decfa7906937744c03b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:264852472aeb2f441259a9e09dee42214c5b2b2b92fef9e6d6fd8b7d2dd709e5
+size 1321
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e3a08865ffdb3a3d0e8e43206a205d6e77e79e25
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e0871515535b0b4b60f31ca114470fc767522ec3e0dc4f987e7db9d724d1dd9
+size 5578526
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.stl
new file mode 100644
index 0000000000000000000000000000000000000000..263bfce5dbb757036183d9c849615e4508891ba9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l0.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f4a7f56bcf4cbfbb72414acd61956ccb8db88ca5ec4074ff626a44cf41a18c1
+size 675584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..253e63fd21bb6796d2357984078aca396175869b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f500967d2a91c6321a453dc3a5949013729354d04c8de4a8c20116e61918408e
+size 473148
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..100f24e77c21881f7d8226f8d147fd4e013190ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d10328be170934db4f575c103d6d65dbf954445ca649690d503797927306767d
+size 593
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8d1e02249686849b4c781be8dbdf465710b1c266
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e551d2b6633fa879fe4d4a69d584922205a44ad1ad9a3899736fcc8ae5d8edc
+size 696392
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..964ea6cdf273aa59cb968bec8b71889df8bc5727
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6d8ba089c3da8a4a40e176a13928e3a39cfffa2fdc83311ac2f6b59035ab6d0
+size 511884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..b83494d7fb80655f2e0d5292b1eea479e1b74ca2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a56dc6354f42d850b35d246266de67c0fb22eb17576e19712340f5909116aed
+size 654737
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..e5c3fec6e71833822eea82eeb9ac9b6a6a2d0233
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfdece18474cb2479a0ae487126743f4a2cae72f1f22e3de8d1a8129089ad71a
+size 957
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..97946cdb06cac7b6bf8e86dc9b06939d7a69a93f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:974a66e72ae987c1cd7f47fefd4756646e298739a7c879244b2c3ff0533fb19c
+size 953303
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..3ef2de15c32d518d145e0d76a78894fd17d563dd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a0d2ccf5668f737409d0e7fc2578b9e24660298d09787852494a7adae8c58b1
+size 133734
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2a6630f83b9739810d4bc5d6506ceefaa28a5587
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84f67d46fcc4acafd8fd7bd620ef6354b884d3f96eb4d09e792b40ce9ead0669
+size 618017
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..144315ce398fb6dc49e14b7e1e6a88972196d5b7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c17f048659d1e2e4929c4e958b60a54353e7f8c70f51441d30c19c919c4f9cec
+size 775
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2e5a528e05e888bca43d8e9760df7ecee4751c1e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52c5c2a65a874a400085a32033397963f68794cbb906780c9bbb6fe52e3f5937
+size 896221
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..f0d4d108e88ba1629d943451db1d68916e810bf6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9faa1f689135fcda53c50b18f9a5c55bc718bb30790395a20821e96b87a174d0
+size 160034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.dae
new file mode 100644
index 0000000000000000000000000000000000000000..08e9bafdf91240f1c1213757797b5d33b63e0a91
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b1742337c74de3ec0841eeb651291a09dadd6cbe84e2b1f26a1c4a99272a1c0
+size 4922491
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..70f5f601a4c2aba96d2b1342836ef1662ed106d9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05f1cc50983886f4963f46ff05852a5553eba76f44d7e80bdc8ffcd9bfc634b1
+size 1503
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0ea231d268aa577cefdc6b98566ed05715cd7f56
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ac9d58aa680a044b59980b81034acc0115c245e8ba4f5902455994e478c6583
+size 7258890
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.stl
new file mode 100644
index 0000000000000000000000000000000000000000..a7d307ae67fac5efc496faebd8050ebbf73dacaa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l4.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9f7032958ae8feef741a4a073f59f3cf9f8f491d93505d0bc9047269cb2e1f2
+size 208284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.dae
new file mode 100644
index 0000000000000000000000000000000000000000..6002ab4f2b52ba188f44ab1f786b71643d60d51d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8468804ab641adc86c10d3d21f22d6e025189bbc8f3a18ba8a7e73f88f49b0e
+size 2120107
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..bd633c15237cd4c54bdb5dcb2ac6bda0225d331c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35afdc0890972e202b4b329de72ff70cc9b878ed5e3379600edf4f71f27faae4
+size 957
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e7d0ca6e3d3c38efa2c925cad241fea20a5a4383
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e0d0f74383ee20bb4dbbab14dc21233a68458d842c7d972c7f68dc27ce55629
+size 2904791
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.stl
new file mode 100644
index 0000000000000000000000000000000000000000..1ca0d1a4164f61bde4838fef7f86312150c5bc45
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l5.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5078849cf3e1ae7790a6014e1ecb51eda84b44d2d91206783bb0b1fd1740e9ff
+size 176534
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.dae
new file mode 100644
index 0000000000000000000000000000000000000000..0a170130340091219653ef8fc94307b7798669ac
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:096155b75cd14bdb9e921c990f358844c9212178f32034a3fa92ef299039b7dc
+size 1990228
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..7c1910e1520fadf6e38711fe4d16876b9d1f0055
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2450351ee19839756c275f770bcc2edf7b476b42eb74b334bf679b85962d7fa2
+size 1139
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ebf7bb4ae33f5338b0248d0674e1859dc942fa12
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e5207527499e30c9fedec5cd3983792c59b2fc932304c1a38b352fe7ee2c615
+size 2838795
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.stl
new file mode 100644
index 0000000000000000000000000000000000000000..84627fc23ba0e3bbcb09e24872cd481bfd4a244d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/meshes/l6.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e55c25b435b2e038f21542cb9315745656fe248cfc1d388364acd7e73e4a28e9
+size 183034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/no_texture_robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/no_texture_robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..cdf2d5b72f0004f781a4e56da4cecad7a0b2ce3a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/no_texture_robot.xml
@@ -0,0 +1,121 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b1bcc9bf5d30a204abd9be1ee0ae6126bd8b0a73
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e96917a254c2ff4ce42dd5ee0c82bb33d632013f9aeb0493b7205fec5d68c2d5
+size 411
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base.obj
new file mode 100644
index 0000000000000000000000000000000000000000..256cb2dab250dd39cad70c40308832a1a3c24097
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aaeb52c143f716686b9ff094c41a2af2e761f8c6b0e633d3bb9500202f1a72c7
+size 1423343
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base/base_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base/base_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9342e2de393daca72cf1ede392e7e7c93ceccb56
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base/base_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e547e145646f227341c5ece469598c3c028f6c344cd8a117b17c3989f7a2195
+size 24137
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base/base_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base/base_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3c7bc03d3ed2beb9b360e79249077feb008d5a51
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/base/base_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a29a1fc7532d40fe1ce6c50f78853f756f6f27648bde00b765918ffe047b78a1
+size 1902755
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..3886714b14a4116ac77b23f3b91ff18403fa2ea1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49c9ecac411073f585ff5fea5ca6f29e90bbdfe4254e63b7bfdb7afb176d5593
+size 1828
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head.obj
new file mode 100644
index 0000000000000000000000000000000000000000..95de473d78781dadf9bd9acf14d8bce2691ed872
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e98e325ac05841d7b3e2fca1cab5d7520b4ff8de51532cf3f615de4f4d03bd6
+size 1966591
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..19b51b856a58b417ba687e85de4a2a1b688a7dd8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f88d7b31887c4e60b0d63fc573ddbb7dc4b11f3553bf5850f0c86e995542cec
+size 17256
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c40664a2ae8ed0e374ca2c3352c702dc1aa6bd2c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:938fbcad656d9d92acf430754e4b86a3cf9da07cea1e8147fc3bcd96c2a8321b
+size 7308
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9a59ea2bba0a0e25d8af539193878a90c734268b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e73560e2ce46ac81db50f40772c11d7efa529914b370cc3f897fb1985abefae6
+size 26789
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..da6563bea56e6618bfa9f9a423aaa24d3cc81b5b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfa20d84f02301c05e5401182601789df5f352c1a96ce74f56430bb724a11282
+size 49795
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3d3097328f2b79caba6b42447f7ebae33f081fd3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c87abdbba03512c99de28290363b253e333f565135caac7c92365f0aec6365b4
+size 67186
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6b7239ec0c0b43a769b311434940f86be62e29dd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36053842c61f0977e5cc89c93621b5abf69807a0a081b453311ef1247483e063
+size 523234
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..797f4e954f3325048b7038395ac68de26b9e7180
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d690b8cf4d4cd0e58cf1f6374cd00e473a136ba10c6713ed132899be748f273b
+size 690999
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_7.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_7.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d584b8bc44a41ba7664eeeb6f765701a886d00a6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_7.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dc7d9fcbe86ebb31b84e18c73a42bb0049ae3c31fd660938aadfc4118eb7d29
+size 67662
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_8.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_8.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5ceca6d4496df69034aa9fd8fd5618a3152bc4a0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_8.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e087cddb34d9b9708c6733d013d46b7e103ce1ce552c689cf3ee8170d13d317
+size 1249668
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_9.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_9.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5b966e2bb7c5e8fbfe5493d65175ad7ca37c5005
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/head/head_9.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e55dbbcef11f2a551d23144d41a09fb61471d131f5605f2d9772f72d550f9a5e
+size 2966
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..599be372537b75781e5f7decfa7906937744c03b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:264852472aeb2f441259a9e09dee42214c5b2b2b92fef9e6d6fd8b7d2dd709e5
+size 1321
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e3a08865ffdb3a3d0e8e43206a205d6e77e79e25
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e0871515535b0b4b60f31ca114470fc767522ec3e0dc4f987e7db9d724d1dd9
+size 5578526
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..98086c432394d3c04c0a42342ea1c8a197069085
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5ccce2f6632ad8684991034e90d8ea06741e577802c522d41fb6f6f09735ab5
+size 983
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b1e30caeb9fc3747d337bcdebd02051c95bc1326
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c2f48a54ac9445b0540a5d811c3a33f6eaa678866e50dbc1dbe971344c7dd6c
+size 10332
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..617d81a94092ede96a00ede0978190b8301c197d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b423c02e61c90648623f1884fbb78e21562ff6975e682f29725850d44b23839b
+size 1510147
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ff81ffde5cffafb658a42c69ed1c012101c77278
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b605855763d12255ea647caabfa40a0c13985fc8ae16507572737072b5fb3ff7
+size 19377
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..75d398fa59b508c4631f44a15074f992f96226ad
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:096c49f356dc63fe1b1b2a7a5157d9d32f33fe8735bf6844fed9ed706abcbc64
+size 2230
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6ad39a219b2b6db974fabfa8c25f80c5aeab59e0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa16edffa27981e6c8b9b39178b6e8ab694bf349aa9f4a2366d296db5b002aca
+size 2630321
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..a5c58d942a68ff57db57f4bfed24bb08f02f498e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l0/l0_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1af4d3407a79dcaebee0fec864cc140154f6b19d784a79a6d60b514277c2851d
+size 3085040
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..100f24e77c21881f7d8226f8d147fd4e013190ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d10328be170934db4f575c103d6d65dbf954445ca649690d503797927306767d
+size 593
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8d1e02249686849b4c781be8dbdf465710b1c266
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e551d2b6633fa879fe4d4a69d584922205a44ad1ad9a3899736fcc8ae5d8edc
+size 696392
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..664929843f0c3a1b3a13ce3b32ed3017285af627
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffab362a1d710a8152644a848c877f8f63bdeb3754b6e6c6a4104ef1c1384db7
+size 243190
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..21bb8380543343ae4d986d940008d7072a8bee52
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c531ddfbd7271e198ff7ed5719cb60d123f713b7b930c04a3d9179b01bd25c46
+size 485196
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bcc9435f15310afae41e263ee4502a4452b80faf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l1/l1_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e87f5d62e04421430e6698473ba734c07db99d5c41cdb27054d5b5c80c45126
+size 174477
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..e5c3fec6e71833822eea82eeb9ac9b6a6a2d0233
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfdece18474cb2479a0ae487126743f4a2cae72f1f22e3de8d1a8129089ad71a
+size 957
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..97946cdb06cac7b6bf8e86dc9b06939d7a69a93f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:974a66e72ae987c1cd7f47fefd4756646e298739a7c879244b2c3ff0533fb19c
+size 953303
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..61c1ef1fd630555bc32f1a8a642287ee677caf4e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b5ac9fb3b46214d19d4f57c8ba3e7c483ddeff5f1a66326d32fd26f60014a5b
+size 101609
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5f15ba505defb98e7fb48de5ead684d4325ed081
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a895e0ba190688fd55d875bb6c8f27a8d8011e5c98da13e06742a1c350bf6c2
+size 122053
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0863e7307ea569cd31c74fdd9d166dcf8c9fcae7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef1e27a6245d74e83a28d1a36f102dc2aa68e4cad93a82caa93687cb70ba8dbb
+size 22747
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1253584681058fbd0afe531214f250e17147207f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7ad30b0f4b01ab8eddef1b78046d4d91baeae2b59cc596baf5bcf40d091a63f
+size 379478
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0fada19a737cc03e382aa418e0afbeec448f36a9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l2/l2_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fa7f82a997f91817d7e7988c173d35b3f24564cf3bc9c93fe67372afd1dd662
+size 624028
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..144315ce398fb6dc49e14b7e1e6a88972196d5b7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c17f048659d1e2e4929c4e958b60a54353e7f8c70f51441d30c19c919c4f9cec
+size 775
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2e5a528e05e888bca43d8e9760df7ecee4751c1e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52c5c2a65a874a400085a32033397963f68794cbb906780c9bbb6fe52e3f5937
+size 896221
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..f726be80bb56e14aa9410793da5096f9a9178864
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a64c720306984eca6e57e89ee6368a4be02d80791aa89885344ce634c4b73eec
+size 20937
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..9a82cd3ef56b21e8779b8189260ff963fa2a4fae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49d71ef4b973cd88b794666cf6179fd662f26c8da5cb0adf54fe26707ffb96fd
+size 389695
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3132ece57f7e1b4f59fa1cda77cb0073da8eb22e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f4b1336a4d3e96e93f3d7fb2678a4f5f30d261890ba6e9b1bc61231171dbc70
+size 99440
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..91f350d30e81788d8eb346b2eccb2c17c77de6a0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l3/l3_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61e29bac1d16bea9ab7ffbf9b493b1a597c2b14394b92c982bf7490ca64330b6
+size 671873
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..70f5f601a4c2aba96d2b1342836ef1662ed106d9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05f1cc50983886f4963f46ff05852a5553eba76f44d7e80bdc8ffcd9bfc634b1
+size 1503
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0ea231d268aa577cefdc6b98566ed05715cd7f56
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ac9d58aa680a044b59980b81034acc0115c245e8ba4f5902455994e478c6583
+size 7258890
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..eb255bb57ab1a4b84ee4914a89ca8b424ffd3f08
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:934bb948f450ec9237e15af67331e6c15daaa45892dcad88d15d104517823fdc
+size 11477
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8f339cd5773b510563a77e8856420869e418b22a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c672c27bfe7690396af9e54d57ceda3fc89d6f08a69324f7ed38223cac15ee4
+size 107524
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6e48c759323aed20ae02db9d1480f909f854ac87
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccf2480ddb77911343d9e59f5cad9d4838f32631de126122bc5c22399432e57e
+size 633606
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1f29511a5354830e12af68a2afb1def0a1147051
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0fe6f9f7dac9085445820563a5baa4a836cd9ae99ccbb02e683cf7806915ca1
+size 995593
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..426ba76f25cb4ee31dbd2049853c47c5f437006b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3977e1b8b7309bb1eb11b375424d45c367d3130ce1418675523b6c57fd7ae24a
+size 644926
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..68f3b265c20660616c8ddbb1876604de5e2a5e78
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a398b7fb72ff3cd40c5fe0ccf780ed086f1e4d9697036d9fe8ef5dba2d12a3c1
+size 865576
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5fda2d361d28ecfc38d051b7197aad3e5abc0da0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:baee574065439546b0d5d5405d8920564589791a095040cd300136f026f38373
+size 1515936
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_7.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_7.obj
new file mode 100644
index 0000000000000000000000000000000000000000..598a52cf5631eea09860f77c0f9911315c418a92
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l4/l4_7.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd86085b7e95f9321f4571fafdb1b3c40e2a4efb5b056d6d97e98b293a4733f3
+size 4653975
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..bd633c15237cd4c54bdb5dcb2ac6bda0225d331c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35afdc0890972e202b4b329de72ff70cc9b878ed5e3379600edf4f71f27faae4
+size 957
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e7d0ca6e3d3c38efa2c925cad241fea20a5a4383
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e0d0f74383ee20bb4dbbab14dc21233a68458d842c7d972c7f68dc27ce55629
+size 2904791
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..7d9819671ac1f189127d7a90746e7c7a0e40fa6c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:752b0b4844e7c9f7abd82030d87fe14703cd175d13236cb5878c27ff91f90ad7
+size 28915
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bfcb392bcb8c2b7cdf244d2e28c68327bb510ed0
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f27709c10f5330e989c235fb943ad88314eb3a39d58f4edbb71391fe67476d26
+size 13227
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1de25c0a42b3ed9ef0f49f367c749644b138fd12
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b0bfab2eb17fee7e743c4c9ce819abab2681a11c3f34b9a1960261d374a8530
+size 227206
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5314a137e24676c14cedfafd58e063cb2adaae5c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c7f821e5262fce470a2789b1fa16b97217b94a5fb8800237065d2a2c529105d
+size 2903969
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6b5d58dc8987c559c0509f2dc6fb9a6baba55d40
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l5/l5_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ea77ab10772e5530c52205f850785490145eefe95fff213265a814add4538a0
+size 716571
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..7c1910e1520fadf6e38711fe4d16876b9d1f0055
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2450351ee19839756c275f770bcc2edf7b476b42eb74b334bf679b85962d7fa2
+size 1139
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ebf7bb4ae33f5338b0248d0674e1859dc942fa12
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e5207527499e30c9fedec5cd3983792c59b2fc932304c1a38b352fe7ee2c615
+size 2838795
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..234f0c11741af840626ce696dfc4fbbe08d6f06c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:033d560b31cf1ebfe4cdbc6de3c7f2c23ec8c547eb15533237dc8582ae5d7ee4
+size 9443
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e361b280d20aa5ffb8c6d05ba275a524183d9bf7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea00b3efde74ab54e978e9940bf0f9036d69f7b792e71004238da8a0f5542584
+size 24130
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6f8259eaeb779183701c3a64858bf6949bce5dca
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9eea7ea5457978d3eeba790d923878dceaa118fec00c318292f6421ef3e01380
+size 4713
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4dbe13c95c9571c942399f94359d6e413f704d41
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a335fb9fe5b133e89912c5b0dad19f84642543fed8f7af3635412f6307684f6c
+size 80416
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_4.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_4.obj
new file mode 100644
index 0000000000000000000000000000000000000000..e8d6f9f96194032504770309d3b0c39c37f0f80e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_4.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb7c8f765540a84023f2e4d3b319623c1a2c1bcbac0309a921226f419f505cb5
+size 127507
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_5.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..01263012338d344789f4343c14c63c5014df6b70
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/obj_meshes/l6/l6_5.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07488040bedbdbfd6aff80384fc287be45229e40614512501781860ce21efdc3
+size 3783211
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..ad115e2c5fbc12235da3914491cd86c348a72c35
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/sawyer/robot.xml
@@ -0,0 +1,281 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base.dae
new file mode 100644
index 0000000000000000000000000000000000000000..858cb12e429b814657e2cef0309956e930c18dd7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7d37bb19bd062ca80ea00dfcd758b1145455c5cfad87d41756b101ac5b2a8a4
+size 358055
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base.stl
new file mode 100644
index 0000000000000000000000000000000000000000..17532f8b2337d284b57814b0b04df1cad492e7f1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a145e0d46f2130afdf2a2e8825a00a929870c4c3d6d8e4d1adc5f04db3aac1b
+size 21084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..5ac838f4f39b08b5b8a5bca10f4809b89b5cd246
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:663310d22dab14745f7a6067ba2b02f7c390474250b9a2d1330cbb55728e2f71
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fa9db15c0bc4eeeddca2f786cd6c6c5de31a6dbd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:403e84ad717168f90758115872ac35f74a5a5330e4a76202ec461d82b9cbab79
+size 540706
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..831cd04e107d9527ea3385521d0b581dd65a8a6b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/base_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3007a4de228bc656964eea51203a2d04ae27a8be57f397b76e83a537f0df866a
+size 240784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm.dae
new file mode 100644
index 0000000000000000000000000000000000000000..2483da94c077d3cb5a078906a71730bda836e628
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cb2f094ffba59f124f70881cabea985abfd399f2fbbe76fba7c18d2ece943b9
+size 1140936
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm.stl
new file mode 100644
index 0000000000000000000000000000000000000000..f45e03826c43012b7d691055029ff858dd11173d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e7423cab807c34160ec4f770daee5e747d70e777eb01b7beeace2b8c5751816
+size 53284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f952494dc4dc7877b9a59a4a451a068ab06ed939
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34ce0be2e1632e4fe37bf494e3893c75f2fc090be4290b65927658a123af7e0a
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1f2d17cad0bbe2c5d6c04fb61b63d5447f734726
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7496b538aeffefcdeb3885f339e3c636516a123d7ab85ef8c75530f203bb00e9
+size 1466045
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..28afda347a5f3a3750bd9307e772dc49ce6b190f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/forearm_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edc0313de11350874137be754939293c2b4ae810cad6aace786692b1aed8a640
+size 648934
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.dae
new file mode 100644
index 0000000000000000000000000000000000000000..3d1b182cfdd00212645903b3980e1597d0cbd76e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9a74be4ae114c9acafccf68e1e49d8fd815ec030012ce1b60bdcf9b30db49f5
+size 2734652
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..679543fc2320f420d0679415d38b000ac9b756b7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88a713a17094f06d1c8a540dc80b92ab9ee1d504564ac879e4d181f34af46ae8
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4ee6a54adcfd8339d2a4e7ba79891734fe60b3ca
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/pedestal.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2ddd6d21418e64e98195e454b8e684b20ca15baa302dec8ce16f19fb0a48498
+size 2817264
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder.dae
new file mode 100644
index 0000000000000000000000000000000000000000..d942729f28425811f372408f256c60f5c3279f35
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2b58abbef50ce03d4704465d3b619a7da5b2ecb2efc236eadfd221116cbbef0
+size 1797082
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2afa86568142f0b614e766e149fd19b5ca23ff43
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ceb92532177daa77682f5fbd628e01c2137d168f949a7a706ce1dabe9f002387
+size 70084
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..93b0626b2e49a4d9b4d774ebdc5bbacf3a693caa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ea38dba4d482ea7d337c2aa84765da479f196af8428033f89d4b4172e396e55
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bcb6e0147c2ff39737753215242c119cb292e487
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bbce6bd41475f84a0c5fecb83791474558070d06caaa82196a558538c32f11c
+size 2416493
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..35dc031611f777189e0d801a5a8c504ee744a75e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/shoulder_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb64e003df6d9fd0c3716bcc662c4363bad8b66de2bd564de66156aad0c67403
+size 1056884
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm.dae
new file mode 100644
index 0000000000000000000000000000000000000000..712cceaaaffa5c2e96e060dbd4304faeee0ef8b6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6486526456e11585a41080958b97a0c8821da856e93b2059c74b07f8102bf6cd
+size 3082485
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm.stl
new file mode 100644
index 0000000000000000000000000000000000000000..d0535c386a41cb0f1086573e24660180d3635c2d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee893044caf00075cb55b4cf666d1f1311c7979786212a501009f33bee945209
+size 99684
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..b84076060ec8dc8d69265a37606bdc57c9092fd2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f6867485b58212fbbbc719e169c447cfbe96a11d695dfb20868f9c8fafd3461
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4ee9a68d1cf5453558a95d6d23cf82b2f0bdc16e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6f7168e435e18220cdf0d61cf67e5ec140d74651b7f9eea6a1138afccd5a696
+size 4049840
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2c0c6cd88a2cb65056d64b0f64f77581393c8700
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/upperarm_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3805393009270b717f7eb4564d0fd9fb75b1e2199158a82dbd7216ea79f29f2
+size 1706034
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1.dae
new file mode 100644
index 0000000000000000000000000000000000000000..4eaa7dba1291817b44551da862aadb5c7d4823ba
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c36649cf4deec6da427d72d45163a059dcb668eb61aac760d5f2f979948fa13a
+size 1334662
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1.stl
new file mode 100644
index 0000000000000000000000000000000000000000..a4c0a83322115f8d31fbcf8a8a83968fd01497fa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8c9f9337b6fd98c75f052e96de10e14a107ddb6874ba6b904e546f8a4e4f43a
+size 59584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f11c62e9b0a5379da28889c2c8fd3f20e97f68b9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95f2b261b30a808efdecda2abd67b18fb695ed9a0691586f299274de35617327
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fde49bf1532ce3d03160f487fbcb7e8e456855aa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6eaeff6ea4f4641e142f0f8cd191c01d3d1547df095e5720671e1c172872bd8
+size 1800469
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..83d5d014053a35dcf10f05768f722b0275de98a5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist1_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c9d1740c05ace364f4a24447a0d819da48755b31235242c69c8bc15ea5c61d6
+size 806284
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2.dae
new file mode 100644
index 0000000000000000000000000000000000000000..351f8942d4d382b536eb75fe4570e97b48be4596
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c99f5538357a7b42a9b207a12e52e713dd3eb7e587645144bde7ab31c1cfc76b
+size 1554838
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2.stl
new file mode 100644
index 0000000000000000000000000000000000000000..44ce896e6b5194de0a569779e5dd2e8766552685
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2964a63f60ce3e3cf3ad55bcf190d7876d50e373cb64b70a57cea5885eaf3c86
+size 67584
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c6ad2d69c16350e3d807cce64fad324f526c97fa
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d711567f9decdeb7f44b848e75c44cd119aab9422b1ae48cb2f5ccd6747aa13
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d4e506eaf1b7ac02bc9c4d2b38536a850aeb31f6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d92696777d24520869088c69615f0a2a4bcbbcb8b6675d2e97d3e8ee3a2e85a
+size 2099344
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..928005c93f54aa784dd9d66003100469d61b4028
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist2_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c16241088d427c1fd263be78220908b530ed296682206ffe9430b36c84b4ab74
+size 946784
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3.dae b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3.dae
new file mode 100644
index 0000000000000000000000000000000000000000..4a91d5b7bd2c2709dc0ddbc0a8391327570fbcc7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3.dae
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f5d4f544ae72698a41f75ed4897a76386a91b814dff22e30e93a5bd105a717c
+size 66076
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3.stl
new file mode 100644
index 0000000000000000000000000000000000000000..2db25af90ecf3a43388f22482dd8344c34380df5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83b3666b4ae2badd54af0d2c25a921682ecbc29e849eec646c3ed55fb74c78a3
+size 7184
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..9d87daa5e45f71ea19edb0c740285eb916e0cd65
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57585ece09efa8e1722f23649baa44b67758dcbb08662dd248da0e0b8e5e0858
+size 236
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..26707f2f604136335a06a01531b7f50e89773438
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:caa699b340ddbb85003d2356ac41e9c020585c77d42d9e268414c2665ef92b97
+size 92076
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.stl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.stl
new file mode 100644
index 0000000000000000000000000000000000000000..7a0079ec3cdc20d8224d4dd81f7ec645540399f4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/meshes/wrist3_vis.stl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:275a0adeb5737c0ea951036071826a348d98c464bf0d5000a896935241e65712
+size 45634
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/no_texture_robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/no_texture_robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d5fd5f77e9c1053110b6cd381d5fb25bdb8630db
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/no_texture_robot.xml
@@ -0,0 +1,80 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..777a0cdf4c41894fe746f81e6f94085767caaed4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27e8ad5a2fbbd8dd54a7308654a1224451b7c8ecd05b39cc461bf9c5974f841b
+size 421
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..62412afd1e36c4545790155678a2d7216f6f71d8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0297e6e882404a2e1f34d4d18e0b4a1c4d9d4d08d5360b20981c8bf15bb75666
+size 344266
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis/base_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis/base_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..2c22fd6e2346ec4cb4262cc77d0850f068281f66
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis/base_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:192729de151e0982652acca1a97b64337a9ef22bd3ddf2420fa2b8e782afd3f9
+size 555556
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis/base_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis/base_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..17b44e9f3b19e43d3569314d87945f53206ec711
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/base_vis/base_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:436805384ad1f8f2ccaf028fed33596df4e156656fe4eeb73bfb1d32f51ccebd
+size 638730
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..330f355c55c28e067ecea287c0fb3b8c26dc664b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66c34739f2772474b80b3f86381ca2810b4bc7e036dd6f4311d59661cf85bee7
+size 775
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fb30cdc5b32c36883200b315e93b6ba437425b2f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e5a706619c54ebc0e74d0cdc838f63066a47df43156cb431813243f86e14946
+size 1134710
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5f16f8d858ed1d92ce0c531b10b9d0f290bc2793
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db3cc5aabd8b4ff3e2392f7aa989a1ac3829a8c692cc84a44fc6d161e0c9e25f
+size 1692114
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..86331383f406767a7551016407bfe5ddcdc0344c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:593487745160c53175b1329aa7a56275fd4ffad24c86e56966fbcaafe00cb74e
+size 76433
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..1aeb038fb9310f7928a1baf6a019f0d50a9ffd8f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f68fd607a6166a0e2fcb162f472db24aa7008a369d878b700d5ca3f499e3f850
+size 933173
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..64b386b9c40b280be19cc8bf738d8f67bc8406ee
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/forearm_vis/forearm_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0fe9c0c9eb1b2612906effab179846c71b82c910ea73dfac3198ad1854189fd
+size 1096820
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..7b6c830eb2405b236f1024d32554617033e6ec3a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d47f710de6c043902d31ef3133711e042ebc2dff1b4a6aae871a557cea494488
+size 605
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c00b27d0d135a109ad67ec7671d135d555181d93
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e910c05bfc38f0d8b31706074ddbc82e041d2e6feb07467539e5a537033c819
+size 1796581
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..0f708fabf54a029e23c7f7af6a8cddbfa8da7cea
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6f46fbcdc3eb91856d29a54acbbb39242f422dcc62b6d5590685b057cf8be1a
+size 2946979
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..6b7753b61129c27c6e5b963ab89533b079024e6b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:922264d6abbaaa4d7cce453b3371dc7b9dfc8dd46961a64c835b431331e6643d
+size 569961
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5b4fa73f4d0cd47cf2d6be8476bc01b72a9506b4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/shoulder_vis/shoulder_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8b5432ca017f953da2ce264bfb8de2ac2b61eb5c85d472d06d4b9da3c3c9ab8
+size 2441153
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..a4aa777df8ceaa3f387f491b3ee69e800af61cbe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fd967fc19fbe1cf8e88fe75ec15d8d61e6c80358f1fbf23bcc534d2684fdbd0
+size 791
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..fb292eafbc8057b387bf59e4924c7caeabd401c6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07a32b121699a5e282caee0478f9ce1848715549c6afe6654bd4c9b2eb711c12
+size 3158403
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..3e013b959e6bb18baaa914b67bb76ec0f5e18811
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e53e6302c7937d8c0853fdd600bcd4adf3767f586ed9a12170b4159f8f4a5b1
+size 156260
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..d7764e464e8d940498e6a7684db933efe751afea
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ce040b7f759bbeb6874005457a84239d02690c7eb22fc85d4876eb258ef4c70
+size 1090424
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..4db6b9afff037cccbe6f54ef1a76c0101caa66bc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebc641d2d76b981acd673fb6644895ce56a8c10e0109cbc4a8cc8c59e540e862
+size 3684774
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_3.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_3.obj
new file mode 100644
index 0000000000000000000000000000000000000000..bcf9440ab9ce851f3cae0ac12f3d78e8e4ffabfe
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/upperarm_vis/upperarm_vis_3.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a965ee78d60112a72bcc33d173d6669f09a5854ab157c2c6df1b3f5857bbc3f
+size 5345375
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c124e07a96b651de2f2b799748c947de84c62bdf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b29367ba1b140f1540c05ae8eb0837ccd5f7fb906bd0a323b6e7d8e8d1617af5
+size 605
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..746d1a99369f672a0d615fdb803667e96eebee3f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:022c689e3363f7fbd147ca4c0495489bd5c6e51f06f02c5c36ab0110b2aa7aa3
+size 1319390
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..b45db7eca2622ab5ca0e9cf6f9cf22ad3d022340
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44d267ff1e725fc4a59ed6ab2f2eeb58356871d98d868d8ba08555bbd8b0af5d
+size 278772
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..c19ea037389a5411951b71b8fe79f4be61c95844
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f04a25bc0827e903c9d001a16173e33197641280317b1d9e78e473f476052764
+size 2470993
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..ffbdabf5c24241ba6e6d1f758ecf698bb9891f8e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist1_vis/wrist1_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23e82bbd4d6d21084d1a72d50f4d027d413ac0cb23edcb0b1c76f6c162da6a11
+size 1658162
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..f090d8e29b3f4c781b7de56f5c85c767bcdfa938
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d15722d0ecaedb182f661dbabe67e0fa89a8fb05348458cf7d70f65044d1da2
+size 605
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..830afb040ee47f04d8728ba8019bf60159a83c5f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:715a3abe4cd81466113069a39663ca07249e795be712e613b21127b294142bdc
+size 1548411
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_0.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..8d9ed22f08b55ec5f52d60e1aca9ae6ae1a623af
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_0.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6482ce9227c3a4e0bc5af577c33008fda74723b7e8b72b5ffa6a2b5bf42651c3
+size 846654
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_1.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_1.obj
new file mode 100644
index 0000000000000000000000000000000000000000..68a3731f7d1d5e94eea2a3cb45f66794f4531b7d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_1.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34c774dab175b810eb364b0955feaaa9648ad33bc4cd445f2d87c7a5f0d8faad
+size 2156472
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_2.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_2.obj
new file mode 100644
index 0000000000000000000000000000000000000000..13f899867fbb314fa3367dce49b776a81a5687f2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist2_vis/wrist2_vis_2.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a382e17343660e5757f03fd1f8d5dbd07cbcb5828b3c838815bdb1bf0c9d3004
+size 2141246
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis.mtl b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..c79271118c99ead228f3313844b968e893938cc6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis.mtl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d3ebe91449c0ba69e2d91be33a24185d8ded3f2f975a05b3c504752db1f1692
+size 237
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..24af72a1aa0258d1cab7ededadb8658def10a961
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a36deaacff34e810bbf86381ff2bce5cd3ecbced6486e3336c51649b2067881f
+size 58971
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis/wrist3_vis.obj b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis/wrist3_vis.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5955b8a7a6536575d850a8edb619264553c65886
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/obj_meshes/wrist3_vis/wrist3_vis.obj
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:953c3ab8ff91983ba79cf139eec47853ded3e53943c366aaed84f60e62172e10
+size 211663
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/robot.xml b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/robot.xml
new file mode 100644
index 0000000000000000000000000000000000000000..1d0f101cb287fb286dda3be091a3432d137e889e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/robots/ur5e/robot.xml
@@ -0,0 +1,132 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/blue-wood.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/blue-wood.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca06668e5038d9f5bea19a8045026e0e50cbc94e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/blue-wood.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b1eab21395a05d7463751f9075b16bc4fba8fd4be42938bda25ea634a345884
+size 201294
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/brass-ambra.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/brass-ambra.png
new file mode 100644
index 0000000000000000000000000000000000000000..1221dbf0dd2471cc3d7318a49b66fe03b06f3649
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/brass-ambra.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:246007818045b92276631294bb61d797806f2381d7044e98b6e3c4465f8e085c
+size 1720843
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/bread.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/bread.png
new file mode 100644
index 0000000000000000000000000000000000000000..f23b705dc737519490d20f7f8fce2b1676d8a15d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/bread.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22814a06ae3c7dc097801887e6aec97f830b7368b611cab56d1c4ad5ae57bddc
+size 518677
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/can.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/can.png
new file mode 100644
index 0000000000000000000000000000000000000000..1cbc5117af16dce9ef31b567f240f993f8aece0a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/can.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d1b29ff348a4425f84bb0de6b46139446b7e950058348d52fced0a707558e02
+size 586402
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/ceramic.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/ceramic.png
new file mode 100644
index 0000000000000000000000000000000000000000..891c50a69b45e57cc9befa6f8fcc579d35552fc4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/ceramic.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beef464a2aa9143ba6a25a44adcbfd1f1aff16c5969edbf08019b61025388402
+size 1442030
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/cereal.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/cereal.png
new file mode 100644
index 0000000000000000000000000000000000000000..5486bb143fe9407be3e5f9d34f30e95d61b543cb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/cereal.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa0be8e701b41b3a7284bcd1e6e18adb541d007d00b2dbfe69772620bea790c8
+size 542717
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/clay.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/clay.png
new file mode 100644
index 0000000000000000000000000000000000000000..d264b69c1a3d4be139085c9c6bacab5ef34c7049
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/clay.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ea2ced7f0f07996292314e3f3f5df68c4d8c5b2cbf2180c899ed269906fdcd2
+size 634465
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/cream-plaster.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/cream-plaster.png
new file mode 100644
index 0000000000000000000000000000000000000000..d77333eb1bd7d44057939b3c12c8d9a620e4a8e1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/cream-plaster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f989033f089158b2c8f0f1f7881f6922b018a5b154558b2edcc9f41a21ab983
+size 696397
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/dark-wood.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/dark-wood.png
new file mode 100644
index 0000000000000000000000000000000000000000..09784cca1e9f2de13f0306189a9ad21aaa47241b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/dark-wood.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5da8dbf7af06ecd251840aceec25fedcfa98267041cea289b2829c982e008dc3
+size 239227
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/dirt.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/dirt.png
new file mode 100644
index 0000000000000000000000000000000000000000..eaeda47e2951d5cf35cc9881571235ebeb3fcb4c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/dirt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b410966a8e0c1d52e2847bdb61aa51c6291efcd7e92839923637d0dcdb91e973
+size 533543
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/glass.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/glass.png
new file mode 100644
index 0000000000000000000000000000000000000000..c48c108132eb21a8f4a17a06d778b04d5c22dcb9
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/glass.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e12c4e1ed663ba3b690a701626e80189be02228f6d29c985273f15da4de423a
+size 89051
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-felt.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-felt.png
new file mode 100644
index 0000000000000000000000000000000000000000..66ee5d48f22f4f66a9fe70865865063b14e2f6b2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-felt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04c1b483f04be0bd3d26a42ca530c05244cf4b52c0ccb027dbb8fc2bee41ca05
+size 1037473
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-plaster.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-plaster.png
new file mode 100644
index 0000000000000000000000000000000000000000..fcf71d0d1dcff614a27b5d521f0986c197cc29d2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-plaster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ee7520ebf6118051ced809cb9754983e7537a40f8d5271d5cb2fabdd71c33bd
+size 474599
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-woodgrain.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-woodgrain.png
new file mode 100644
index 0000000000000000000000000000000000000000..e72ed266cce2951efb4c2f125b4defa31719d833
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/gray-woodgrain.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9e65c147f2683dbcfe6279b359db649aef0cb40ef8b1e3ed5ba6f29fab6b3d0
+size 267843
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/green-wood.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/green-wood.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a6cc1080f72267c3640eadd87fcb5418eeeb4e7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/green-wood.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4fd42ec6dea58f46aa96fc1d78d57c52450a6497dcad887ae8e2ff8c4af3797
+size 83480
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/lemon.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/lemon.png
new file mode 100644
index 0000000000000000000000000000000000000000..6e9743ccc21070f6a053e27517d8461b6fbf1217
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/lemon.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93a21938e5fe8d62a042c5591b13b9bec75d94791f7d08f375beb60fce523610
+size 1919770
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-gray-floor-tile.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-gray-floor-tile.png
new file mode 100644
index 0000000000000000000000000000000000000000..70a290f75e6545ae5bc745acadf63566fadd6c2c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-gray-floor-tile.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2aae4ba3b38c3ab9851eca9ad6e76cef9357ef75b3e9a4b280f694ca30acbc5
+size 73373
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-gray-plaster.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-gray-plaster.png
new file mode 100644
index 0000000000000000000000000000000000000000..8553f5e438ccf68cfe9ce9300ce8c94712ba1e8f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-gray-plaster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0bdaf13b44ab6ef3122451f6b3853e40bdf20bb9be7662649683ba525dfe191
+size 516639
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-wood.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-wood.png
new file mode 100644
index 0000000000000000000000000000000000000000..92f9d6aa33937bb75a49fb9543e0c6a74a88aff6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/light-wood.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:967be64bdb170d36283bc43c2ae5d525ec88f78bc837dc7d0efc9a28d68caad1
+size 801173
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/metal.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/metal.png
new file mode 100644
index 0000000000000000000000000000000000000000..f5260ff57151e704fb133e4e0ea88af4a455ee24
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/metal.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93784216a66b52be9c8ccc3732e50f511815291e32f34f97bd139ed7229061fa
+size 196878
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/pink-plaster.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/pink-plaster.png
new file mode 100644
index 0000000000000000000000000000000000000000..08abc2bc4cacd7d533d1275e3275626c15195eaf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/pink-plaster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6f8435c8aec555fa89ee72d472d652f31ce1afc0e07372e9857b0aa50f6e029
+size 553251
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/red-wood.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/red-wood.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d1e2e39d020002dc548ff27c7c1b6184d025583
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/red-wood.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d42411e594d6be26c4876ada623b163af6621402462b244d86eaf6e72aca6494
+size 1797019
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/soda.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/soda.png
new file mode 100644
index 0000000000000000000000000000000000000000..5d48f08b91de2820aeb9687371947d4034f11809
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/soda.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2774d29c6c16e1a3b3156e6679897a1695127ec59463d2de1a7ffd7f11b41b4d
+size 627792
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/steel-brushed.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/steel-brushed.png
new file mode 100644
index 0000000000000000000000000000000000000000..a072534801ccb26c8512195f490edc9f1b8295ab
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/steel-brushed.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43f30e7a2d0e8085aa15f58623b1753a003f93e5fe336eeaf65c146f955a81d5
+size 275192
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/steel-scratched.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/steel-scratched.png
new file mode 100644
index 0000000000000000000000000000000000000000..96c24dd05d9dd8ac7be045baf71c668582b1cd7d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/steel-scratched.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e22a111d516cc84addde9b2b18353d8a6ec521b9b79ec5e9c169e2d9000d3fe
+size 391216
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/white-bricks.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/white-bricks.png
new file mode 100644
index 0000000000000000000000000000000000000000..04ce966e64c68766b372a84bfc57323e5bbf7b92
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/white-bricks.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:864e2bfb462939212c4141a060eb4c3f34c891b6909873797b08f8c4185433d1
+size 1458770
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/white-plaster.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/white-plaster.png
new file mode 100644
index 0000000000000000000000000000000000000000..7972a9b8fc1bf5a9ee03919c0cf6ee6f04513ba6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/white-plaster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a35c03edcef74c940400dd39e2384d05169d1d15ac247166b3b7a930acd70f03
+size 624094
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/wood-tiles.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/wood-tiles.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9988e92a4aa98c43e46a274bc9aaa49096a0cf5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/wood-tiles.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18d87069529422d832559ce969858576dd46f83efc3dfd9e5ffbf26d740be123
+size 1460776
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/wood-varnished-panels.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/wood-varnished-panels.png
new file mode 100644
index 0000000000000000000000000000000000000000..28c7819137619bf1eaa2ad1fd77612c4b14c255f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/wood-varnished-panels.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67b48dc99b89004d98cb0811f170d39993e7e205cc8d5bd3d01a4be137e33a85
+size 466710
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/yellow-plaster.png b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/yellow-plaster.png
new file mode 100644
index 0000000000000000000000000000000000000000..0610295bf881b2e6d78f5fc72811e4cd2f1b0702
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/yellow-plaster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f98027e5dd4ad845a1a20f302b183c62b8e8fab54799634c5143c1dde0d0362
+size 477879
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/base.py b/phantom/submodules/phantom-robosuite/robosuite/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e6f98d47904d4c05c8e8c34fb297d8cf570398
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/base.py
@@ -0,0 +1,696 @@
+import io
+import os
+import xml.dom.minidom
+import xml.etree.ElementTree as ET
+
+import robosuite.macros as macros
+from robosuite.utils import XMLError
+from robosuite.utils.mjcf_utils import (
+ _element_filter,
+ add_material,
+ add_prefix,
+ find_elements,
+ recolor_collision_geoms,
+ sort_elements,
+ string_to_array,
+)
+
+
+class MujocoXML(object):
+ """
+ Base class of Mujoco xml file
+ Wraps around ElementTree and provides additional functionality for merging different models.
+ Specially, we keep track of , and
+
+ When initialized, loads a mujoco xml from file.
+
+ Args:
+ fname (str): path to the MJCF xml file.
+ """
+
+ def __init__(self, fname):
+ self.file = fname
+ self.folder = os.path.dirname(fname)
+ self.tree = ET.parse(fname)
+ self.root = self.tree.getroot()
+ self.worldbody = self.create_default_element("worldbody")
+ self.actuator = self.create_default_element("actuator")
+ self.sensor = self.create_default_element("sensor")
+ self.asset = self.create_default_element("asset")
+ self.tendon = self.create_default_element("tendon")
+ self.equality = self.create_default_element("equality")
+ self.contact = self.create_default_element("contact")
+
+ # Parse any default classes and replace them inline
+ default = self.create_default_element("default")
+ default_classes = self._get_default_classes(default)
+ self._replace_defaults_inline(default_dic=default_classes)
+
+ # Remove original default classes
+ self.root.remove(default)
+
+ self.resolve_asset_dependency()
+
+ def resolve_asset_dependency(self):
+ """
+ Converts every file dependency into absolute path so when we merge we don't break things.
+ """
+
+ for node in self.asset.findall("./*[@file]"):
+ file = node.get("file")
+ abs_path = os.path.abspath(self.folder)
+ abs_path = os.path.join(abs_path, file)
+ node.set("file", abs_path)
+
+ def create_default_element(self, name):
+ """
+ Creates a <@name/> tag under root if there is none.
+
+ Args:
+ name (str): Name to generate default element
+
+ Returns:
+ ET.Element: Node that was created
+ """
+
+ found = self.root.find(name)
+ if found is not None:
+ return found
+ ele = ET.Element(name)
+ self.root.append(ele)
+ return ele
+
+ def merge(self, others, merge_body="default"):
+ """
+ Default merge method.
+
+ Args:
+ others (MujocoXML or list of MujocoXML): other xmls to merge into this one
+ raises XML error if @others is not a MujocoXML instance.
+ merges , and of @others into @self
+ merge_body (None or str): If set, will merge child bodies of @others. Default is "default", which
+ corresponds to the root worldbody for this XML. Otherwise, should be an existing body name
+ that exists in this XML. None results in no merging of @other's bodies in its worldbody.
+
+ Raises:
+ XMLError: [Invalid XML instance]
+ """
+ if type(others) is not list:
+ others = [others]
+ for idx, other in enumerate(others):
+ if not isinstance(other, MujocoXML):
+ raise XMLError("{} is not a MujocoXML instance.".format(type(other)))
+ if merge_body is not None:
+ root = (
+ self.worldbody
+ if merge_body == "default"
+ else find_elements(
+ root=self.worldbody, tags="body", attribs={"name": merge_body}, return_first=True
+ )
+ )
+ for body in other.worldbody:
+ root.append(body)
+ self.merge_assets(other)
+ for one_actuator in other.actuator:
+ self.actuator.append(one_actuator)
+ for one_sensor in other.sensor:
+ self.sensor.append(one_sensor)
+ for one_tendon in other.tendon:
+ self.tendon.append(one_tendon)
+ for one_equality in other.equality:
+ self.equality.append(one_equality)
+ for one_contact in other.contact:
+ self.contact.append(one_contact)
+
+ def get_model(self, mode="mujoco"):
+ """
+ Generates a MjModel instance from the current xml tree.
+
+ Args:
+ mode (str): Mode with which to interpret xml tree
+
+ Returns:
+ MjModel: generated model from xml
+
+ Raises:
+ ValueError: [Invalid mode]
+ """
+
+ available_modes = ["mujoco"]
+ with io.StringIO() as string:
+ string.write(ET.tostring(self.root, encoding="unicode"))
+ if mode == "mujoco":
+ import mujoco
+
+ model = mujoco.MjModel.from_xml_string(string.getvalue())
+ return model
+ raise ValueError("Unkown model mode: {}. Available options are: {}".format(mode, ",".join(available_modes)))
+
+ def get_xml(self):
+ """
+ Reads a string of the MJCF XML file.
+
+ Returns:
+ str: XML tree read in from file
+ """
+ with io.StringIO() as string:
+ string.write(ET.tostring(self.root, encoding="unicode"))
+ return string.getvalue()
+
+ def save_model(self, fname, pretty=False):
+ """
+ Saves the xml to file.
+
+ Args:
+ fname (str): output file location
+ pretty (bool): If True, (attempts!! to) pretty print the output
+ """
+ with open(fname, "w") as f:
+ xml_str = ET.tostring(self.root, encoding="unicode")
+ if pretty:
+ parsed_xml = xml.dom.minidom.parseString(xml_str)
+ xml_str = parsed_xml.toprettyxml(newl="")
+ f.write(xml_str)
+
+ def merge_assets(self, other):
+ """
+ Merges @other's assets in a custom logic.
+
+ Args:
+ other (MujocoXML or MujocoObject): other xml file whose assets will be merged into this one
+ """
+ for asset in other.asset:
+ if (
+ find_elements(root=self.asset, tags=asset.tag, attribs={"name": asset.get("name")}, return_first=True)
+ is None
+ ):
+ self.asset.append(asset)
+
+ def get_element_names(self, root, element_type):
+ """
+ Searches recursively through the @root and returns a list of names of the specified @element_type
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through
+ (e.g.: `self.worldbody`)
+ element_type (str): Name of element to return names of. (e.g.: "site", "geom", etc.)
+
+ Returns:
+ list: names that correspond to the specified @element_type
+ """
+ names = []
+ for child in root:
+ if child.tag == element_type:
+ names.append(child.get("name"))
+ names += self.get_element_names(child, element_type)
+ return names
+
+ @staticmethod
+ def _get_default_classes(default):
+ """
+ Utility method to convert all default tags into a nested dictionary of values -- this will be used to replace
+ all elements' class tags inline with the appropriate defaults if not specified.
+
+ Args:
+ default (ET.Element): Nested default tag XML root.
+
+ Returns:
+ dict: Nested dictionary, where each default class name is mapped to its own dict mapping element tag names
+ (e.g.: geom, site, etc.) to the set of default attributes for that tag type
+ """
+ # Create nested dict to return
+ default_dic = {}
+ # Parse the default tag accordingly
+ for cls in default:
+ default_dic[cls.get("class")] = {child.tag: child for child in cls}
+ return default_dic
+
+ def _replace_defaults_inline(self, default_dic, root=None):
+ """
+ Utility method to replace all default class attributes recursively in the XML tree starting from @root
+ with the corresponding defaults in @default_dic if they are not explicitly specified for ta given element.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively replacing defaults. Only is used by
+ recursive calls
+ default_dic (dict): Nested dictionary, where each default class name is mapped to its own dict mapping
+ element tag names (e.g.: geom, site, etc.) to the set of default attributes for that tag type
+ """
+ # If root is None, this is the top level call -- replace root with self.root
+ if root is None:
+ root = self.root
+ # Check this current element if it contains any class elements
+ cls_name = root.attrib.pop("class", None)
+ if cls_name is not None:
+ # If the tag for this element is contained in our default dic, we add any defaults that are not
+ # explicitly specified in this
+ tag_attrs = default_dic[cls_name].get(root.tag, None)
+ if tag_attrs is not None:
+ for k, v in tag_attrs.items():
+ if root.get(k, None) is None:
+ root.set(k, v)
+ # Loop through all child elements
+ for child in root:
+ self._replace_defaults_inline(default_dic=default_dic, root=child)
+
+ @property
+ def name(self):
+ """
+ Returns name of this MujocoXML
+
+ Returns:
+ str: Name of this MujocoXML
+ """
+ return self.root.get("model")
+
+
+class MujocoModel(object):
+ """
+ Base class for all simulation models used in mujoco.
+
+ Standardizes core API for accessing models' relevant geoms, names, etc.
+ """
+
+ def correct_naming(self, names):
+ """
+ Corrects all strings in @names by adding the naming prefix to it and returns the name-corrected values
+
+ Args:
+ names (str, list, or dict): Name(s) to be corrected
+
+ Raises:
+ TypeError: [Invalid input type]
+ """
+ if type(names) is str:
+ return self.naming_prefix + names if not self.exclude_from_prefixing(names) else names
+ elif type(names) is list:
+ return [self.naming_prefix + name if not self.exclude_from_prefixing(name) else name for name in names]
+ elif type(names) is dict:
+ names = names.copy()
+ for key, val in names.items():
+ names[key] = self.correct_naming(val)
+ return names
+ else:
+ # Assumed to be type error
+ raise TypeError("Error: type of 'names' must be str, list, or dict!")
+
+ def set_sites_visibility(self, sim, visible):
+ """
+ Set all site visual states for this model.
+
+ Args:
+ sim (MjSim): Current active mujoco simulation instance
+ visible (bool): If True, will visualize model sites. Else, will hide the sites.
+ """
+ # Loop through all visualization geoms and set their alpha values appropriately
+ for vis_g in self.sites:
+ vis_g_id = sim.model.site_name2id(vis_g)
+ if (visible and sim.model.site_rgba[vis_g_id][3] < 0) or (
+ not visible and sim.model.site_rgba[vis_g_id][3] > 0
+ ):
+ # We toggle the alpha value
+ sim.model.site_rgba[vis_g_id][3] = -sim.model.site_rgba[vis_g_id][3]
+
+ def exclude_from_prefixing(self, inp):
+ """
+ A function that should take in an arbitrary input and return either True or False, determining whether the
+ corresponding name to @inp should have naming_prefix added to it. Must be defined by subclass.
+
+ Args:
+ inp (any): Arbitrary input, depending on subclass. Can be str, ET.Element, etc.
+
+ Returns:
+ bool: True if we should exclude the associated name(s) with @inp from being prefixed with naming_prefix
+ """
+ raise NotImplementedError
+
+ @property
+ def name(self):
+ """
+ Name for this model. Should be unique.
+
+ Returns:
+ str: Unique name for this model.
+ """
+ raise NotImplementedError
+
+ @property
+ def naming_prefix(self):
+ """
+ Generates a standardized prefix to prevent naming collisions
+
+ Returns:
+ str: Prefix unique to this model.
+ """
+ raise NotImplementedError
+
+ @property
+ def root_body(self):
+ """
+ Root body name for this model. This should correspond to the top-level body element in the equivalent mujoco xml
+ tree for this object.
+ """
+ raise NotImplementedError
+
+ @property
+ def bodies(self):
+ """
+ Returns:
+ list: Body names for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def joints(self):
+ """
+ Returns:
+ list: Joint names for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def actuators(self):
+ """
+ Returns:
+ list: Actuator names for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def sites(self):
+ """
+ Returns:
+ list: Site names for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def sensors(self):
+ """
+ Returns:
+ list: Sensor names for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def contact_geoms(self):
+ """
+ List of names corresponding to the geoms used to determine contact with this model.
+
+ Returns:
+ list: relevant contact geoms for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def visual_geoms(self):
+ """
+ List of names corresponding to the geoms used for visual rendering of this model.
+
+ Returns:
+ list: relevant visual geoms for this model
+ """
+ raise NotImplementedError
+
+ @property
+ def important_geoms(self):
+ """
+ Geoms corresponding to important components of this model. String keywords should be mapped to lists of geoms.
+
+ Returns:
+ dict of list: Important set of geoms, where each set of geoms are grouped as a list and are
+ organized by keyword string entries into a dict
+ """
+ raise NotImplementedError
+
+ @property
+ def important_sites(self):
+ """
+ Dict of sites corresponding to the important site geoms (e.g.: used to aid visualization during sim).
+
+ Returns:
+ dict: Important site geoms, where each specific geom name is mapped from keyword string entries
+ in the dict
+ """
+ raise NotImplementedError
+
+ @property
+ def important_sensors(self):
+ """
+ Dict of important sensors enabled for this model.
+
+ Returns:
+ dict: Important sensors for this model, where each specific sensor name is mapped from keyword string
+ entries in the dict
+ """
+ raise NotImplementedError
+
+ @property
+ def bottom_offset(self):
+ """
+ Returns vector from model root body to model bottom.
+ Useful for, e.g. placing models on a surface.
+ Must be defined by subclass.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ raise NotImplementedError
+
+ @property
+ def top_offset(self):
+ """
+ Returns vector from model root body to model top.
+ Useful for, e.g. placing models on a surface.
+ Must be defined by subclass.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ raise NotImplementedError
+
+ @property
+ def horizontal_radius(self):
+ """
+ Returns maximum distance from model root body to any radial point of the model.
+
+ Helps us put models programmatically without them flying away due to a huge initial contact force.
+ Must be defined by subclass.
+
+ Returns:
+ float: radius
+ """
+ raise NotImplementedError
+
+
+class MujocoXMLModel(MujocoXML, MujocoModel):
+ """
+ Base class for all MujocoModels that are based on a raw XML file.
+
+ Args:
+ fname (str): Path to relevant xml file from which to create this robot instance
+ idn (int or str): Number or some other unique identification string for this model instance
+ """
+
+ def __init__(self, fname, idn=0):
+ super().__init__(fname)
+
+ # Set id and add prefixes to all body names to prevent naming clashes
+ self.idn = idn
+
+ # Define other variables that get filled later
+ self.mount = None
+
+ # Define filter method to automatically add a default name to visual / collision geoms if encountered
+ group_mapping = {
+ None: "col",
+ "0": "col",
+ "1": "vis",
+ }
+ ctr_mapping = {
+ "col": 0,
+ "vis": 0,
+ }
+
+ def _add_default_name_filter(element, parent):
+ # Run default filter
+ filter_key = _element_filter(element=element, parent=parent)
+ # Also additionally modify element if it is (a) a geom and (b) has no name
+ if element.tag == "geom" and element.get("name") is None:
+ group = group_mapping[element.get("group")]
+ element.set("name", f"g{ctr_mapping[group]}_{group}")
+ ctr_mapping[group] += 1
+ # Return default filter key
+ return filter_key
+
+ # Parse element tree to get all relevant bodies, joints, actuators, and geom groups
+ self._elements = sort_elements(root=self.root, element_filter=_add_default_name_filter)
+ assert (
+ len(self._elements["root_body"]) == 1
+ ), "Invalid number of root bodies found for robot model. Expected 1," "got {}".format(
+ len(self._elements["root_body"])
+ )
+ self._elements["root_body"] = self._elements["root_body"][0]
+ self._elements["bodies"] = (
+ [self._elements["root_body"]] + self._elements["bodies"]
+ if "bodies" in self._elements
+ else [self._elements["root_body"]]
+ )
+ self._root_body = self._elements["root_body"].get("name")
+ self._bodies = [e.get("name") for e in self._elements.get("bodies", [])]
+ self._joints = [e.get("name") for e in self._elements.get("joints", [])]
+ self._actuators = [e.get("name") for e in self._elements.get("actuators", [])]
+ self._sites = [e.get("name") for e in self._elements.get("sites", [])]
+ self._sensors = [e.get("name") for e in self._elements.get("sensors", [])]
+ self._contact_geoms = [e.get("name") for e in self._elements.get("contact_geoms", [])]
+ self._visual_geoms = [e.get("name") for e in self._elements.get("visual_geoms", [])]
+ self._base_offset = string_to_array(self._elements["root_body"].get("pos", "0 0 0"))
+
+ # Update all xml element prefixes
+ add_prefix(root=self.root, prefix=self.naming_prefix, exclude=self.exclude_from_prefixing)
+
+ # Recolor all collision geoms appropriately
+ recolor_collision_geoms(root=self.worldbody, rgba=self.contact_geom_rgba)
+
+ # Add default materials
+ if macros.USING_INSTANCE_RANDOMIZATION:
+ tex_element, mat_element, _, used = add_material(root=self.worldbody, naming_prefix=self.naming_prefix)
+ # Only add if material / texture was actually used
+ if used:
+ self.asset.append(tex_element)
+ self.asset.append(mat_element)
+
+ def exclude_from_prefixing(self, inp):
+ """
+ By default, don't exclude any from being prefixed
+ """
+ return False
+
+ @property
+ def base_offset(self):
+ """
+ Provides position offset of root body.
+
+ Returns:
+ 3-array: (x,y,z) pos value of root_body body element. If no pos in element, returns all zeros.
+ """
+ return self._base_offset
+
+ @property
+ def name(self):
+ return "{}{}".format(type(self).__name__, self.idn)
+
+ @property
+ def naming_prefix(self):
+ return "{}_".format(self.idn)
+
+ @property
+ def root_body(self):
+ return self.correct_naming(self._root_body)
+
+ @property
+ def bodies(self):
+ return self.correct_naming(self._bodies)
+
+ @property
+ def joints(self):
+ return self.correct_naming(self._joints)
+
+ @property
+ def actuators(self):
+ return self.correct_naming(self._actuators)
+
+ @property
+ def sites(self):
+ return self.correct_naming(self._sites)
+
+ @property
+ def sensors(self):
+ return self.correct_naming(self._sensors)
+
+ @property
+ def contact_geoms(self):
+ return self.correct_naming(self._contact_geoms)
+
+ @property
+ def visual_geoms(self):
+ return self.correct_naming(self._visual_geoms)
+
+ @property
+ def important_sites(self):
+ return self.correct_naming(self._important_sites)
+
+ @property
+ def important_geoms(self):
+ return self.correct_naming(self._important_geoms)
+
+ @property
+ def important_sensors(self):
+ return self.correct_naming(self._important_sensors)
+
+ @property
+ def _important_sites(self):
+ """
+ Dict of sites corresponding to the important site geoms (e.g.: used to aid visualization during sim).
+
+ Returns:
+ dict: Important site geoms, where each specific geom name is mapped from keyword string entries
+ in the dict. Note that the mapped sites should be the RAW site names found directly in the XML file --
+ the naming prefix will be automatically added in the public method call
+ """
+ raise NotImplementedError
+
+ @property
+ def _important_geoms(self):
+ """
+ Geoms corresponding to important components of this model. String keywords should be mapped to lists of geoms.
+
+ Returns:
+ dict of list: Important set of geoms, where each set of geoms are grouped as a list and are
+ organized by keyword string entries into a dict. Note that the mapped geoms should be the RAW geom
+ names found directly in the XML file -- the naming prefix will be automatically added in the
+ public method call
+ """
+ raise NotImplementedError
+
+ @property
+ def _important_sensors(self):
+ """
+ Dict of important sensors enabled for this model.
+
+ Returns:
+ dict: Important sensors for this model, where each specific sensor name is mapped from keyword string
+ entries in the dict. Note that the mapped geoms should be the RAW sensor names found directly in the
+ XML file -- the naming prefix will be automatically added in the public method call
+ """
+ raise NotImplementedError
+
+ @property
+ def contact_geom_rgba(self):
+ """
+ RGBA color to assign to all contact geoms for this model
+
+ Returns:
+ 4-array: (r,g,b,a) values from 0 to 1 for this model's set of contact geoms
+ """
+ raise NotImplementedError
+
+ @property
+ def bottom_offset(self):
+ """
+ Returns vector from model root body to model bottom.
+ Useful for, e.g. placing models on a surface.
+ By default, this corresponds to the root_body's base offset.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ return self.base_offset
+
+ @property
+ def top_offset(self):
+ raise NotImplementedError
+
+ @property
+ def horizontal_radius(self):
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2814e50ff77bc148113bd9e3824e8f63b56bfd1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/__init__.py
@@ -0,0 +1,31 @@
+from .gripper_model import GripperModel
+from .gripper_factory import gripper_factory
+from .gripper_tester import GripperTester
+
+from .panda_gripper import PandaGripper
+from .rethink_gripper import RethinkGripper
+from .robotiq_85_gripper import Robotiq85Gripper
+from .robotiq_gripper_85_real_kinova import Robotiq85GripperRealKinova
+from .robotiq_three_finger_gripper import RobotiqThreeFingerGripper, RobotiqThreeFingerDexterousGripper
+from .panda_gripper import PandaGripper
+from .jaco_three_finger_gripper import JacoThreeFingerGripper, JacoThreeFingerDexterousGripper
+from .robotiq_140_gripper import Robotiq140Gripper
+from .wiping_gripper import WipingGripper
+from .null_gripper import NullGripper
+
+
+GRIPPER_MAPPING = {
+ "RethinkGripper": RethinkGripper,
+ "PandaGripper": PandaGripper,
+ "JacoThreeFingerGripper": JacoThreeFingerGripper,
+ "JacoThreeFingerDexterousGripper": JacoThreeFingerDexterousGripper,
+ "WipingGripper": WipingGripper,
+ "Robotiq85Gripper": Robotiq85Gripper,
+ "Robotiq140Gripper": Robotiq140Gripper,
+ "RobotiqThreeFingerGripper": RobotiqThreeFingerGripper,
+ "RobotiqThreeFingerDexterousGripper": RobotiqThreeFingerDexterousGripper,
+ "Robotiq85GripperRealKinova": Robotiq85GripperRealKinova,
+ None: NullGripper,
+}
+
+ALL_GRIPPERS = GRIPPER_MAPPING.keys()
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_factory.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..133dce39b24ec446089e79267fe9fd8247259e22
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_factory.py
@@ -0,0 +1,29 @@
+"""
+Defines a string based method of initializing grippers
+"""
+
+
+def gripper_factory(name, idn=0):
+ """
+ Generator for grippers
+
+ Creates a GripperModel instance with the provided name.
+
+ Args:
+ name (None or str): the name of the gripper class
+ idn (int or str): Number or some other unique identification string for this gripper instance
+
+ Returns:
+ GripperModel: requested gripper instance
+
+ Raises:
+ XMLError: [invalid XML]
+ """
+ # Import GRIPPER_MAPPING at runtime so we avoid circular imports
+ from robosuite.models.grippers import ALL_GRIPPERS, GRIPPER_MAPPING
+
+ # Make sure gripper is valid
+ assert name in GRIPPER_MAPPING, f"Unknown gripper name: {name}. Valid options are: {ALL_GRIPPERS}"
+
+ # Generate gripper
+ return GRIPPER_MAPPING[name](idn=idn)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_model.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ebe18b3ae96aa2a0d04134071388dde1c75de10
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_model.py
@@ -0,0 +1,161 @@
+"""
+Defines the base class of all grippers
+"""
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.base import MujocoXMLModel
+from robosuite.utils.mjcf_utils import GRIPPER_COLLISION_COLOR, find_elements, string_to_array
+
+
+class GripperModel(MujocoXMLModel):
+ """
+ Base class for grippers
+
+ Args:
+ fname (str): Path to relevant xml file to create this gripper instance
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, fname, idn=0):
+ super().__init__(fname, idn=idn)
+
+ # Set variable to hold current action being outputted
+ self.current_action = np.zeros(self.dof)
+
+ # Grab gripper offset (string -> np.array -> elements [1, 2, 3, 0] (x, y, z, w))
+ # This is the comopunded rotation with the base body and the eef body as well!
+ base_quat = np.fromstring(self.worldbody[0].attrib.get("quat", "1 0 0 0"), dtype=np.float64, sep=" ")[
+ [1, 2, 3, 0]
+ ]
+ eef_element = find_elements(
+ root=self.root, tags="body", attribs={"name": self.correct_naming("eef")}, return_first=True
+ )
+ eef_relative_quat = string_to_array(eef_element.get("quat", "1 0 0 0"))[[1, 2, 3, 0]]
+ self.rotation_offset = T.quat_multiply(eef_relative_quat, base_quat)
+
+ def format_action(self, action):
+ """
+ Given (-1,1) abstract control as np-array
+ returns the (-1,1) control signals
+ for underlying actuators as 1-d np array
+ """
+ raise NotImplementedError
+
+ # -------------------------------------------------------------------------------------- #
+ # Properties: In general, these are the name-adjusted versions from the private #
+ # subclass implementations pulled from their respective raw xml files #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def naming_prefix(self):
+ return "gripper{}_".format(self.idn)
+
+ @property
+ def speed(self):
+ """
+ How quickly the gripper opens / closes
+
+ Returns:
+ float: Speed of the gripper
+ """
+ return 0.0
+
+ @property
+ def dof(self):
+ """
+ Defines the number of DOF of the gripper
+
+ Returns:
+ int: gripper DOF
+ """
+ return len(self._actuators)
+
+ @property
+ def bottom_offset(self):
+ return np.zeros(3)
+
+ @property
+ def top_offset(self):
+ return np.zeros(3)
+
+ @property
+ def horizontal_radius(self):
+ return 0
+
+ @property
+ def contact_geom_rgba(self):
+ return GRIPPER_COLLISION_COLOR
+
+ # -------------------------------------------------------------------------------------- #
+ # All subclasses must implement the following properties #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def init_qpos(self):
+ """
+ Defines the default rest (open) qpos of the gripper
+
+ Returns:
+ np.array: Default init qpos of this gripper
+ """
+ raise NotImplementedError
+
+ @property
+ def _important_sites(self):
+ """
+ Sites used to aid visualization by human. (usually "grip_site" and "grip_cylinder")
+ (and should be hidden from robots)
+
+ Returns:
+ dict:
+
+ :`'grip_site'`: Name of grip actuation intersection location site
+ :`'grip_cylinder'`: Name of grip actuation z-axis location site
+ :`'ee'`: Name of end effector site
+ :`'ee_x'`: Name of end effector site (x-axis)
+ :`'ee_y'`: Name of end effector site (y-axis)
+ :`'ee_z'`: Name of end effector site (z-axis)
+ """
+ return {
+ "grip_site": "grip_site",
+ "grip_cylinder": "grip_site_cylinder",
+ "ee": "ee",
+ "ee_x": "ee_x",
+ "ee_y": "ee_y",
+ "ee_z": "ee_z",
+ }
+
+ @property
+ def _important_geoms(self):
+ """
+ Geoms corresponding to important components of the gripper (by default, left_finger, right_finger,
+ left_fingerpad, right_fingerpad).
+ Note that these are the raw string names directly pulled from a gripper's corresponding XML file,
+ NOT the adjusted name with an auto-generated naming prefix
+
+ Note that this should be a dict of lists.
+
+ Returns:
+ dict of list: Raw XML important geoms, where each set of geoms are grouped as a list and are
+ organized by keyword string entries into a dict
+ """
+ return {
+ "left_finger": [],
+ "right_finger": [],
+ "left_fingerpad": [],
+ "right_fingerpad": [],
+ }
+
+ @property
+ def _important_sensors(self):
+ """
+ Sensor names for each gripper (usually "force_ee" and "torque_ee")
+
+ Returns:
+ dict:
+
+ :`'force_ee'`: Name of force eef sensor for this gripper
+ :`'torque_ee'`: Name of torque eef sensor for this gripper
+ """
+ return {sensor: sensor for sensor in ["force_ee", "torque_ee"]}
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_tester.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_tester.py
new file mode 100644
index 0000000000000000000000000000000000000000..e297f4f25c42322530857fb0f59c3a32b28406ce
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/gripper_tester.py
@@ -0,0 +1,235 @@
+"""
+Defines GripperTester that is used to test the physical properties of various grippers
+"""
+import xml.etree.ElementTree as ET
+
+import numpy as np
+
+import robosuite.macros as macros
+from robosuite.models.arenas.table_arena import TableArena
+from robosuite.models.objects import BoxObject
+from robosuite.models.world import MujocoWorldBase
+from robosuite.utils import OpenCVRenderer
+from robosuite.utils.binding_utils import MjSim
+from robosuite.utils.mjcf_utils import array_to_string, new_actuator, new_joint
+
+
+class GripperTester:
+ """
+ A class that is used to test gripper
+
+ Args:
+ gripper (GripperModel): A gripper instance to be tested
+ pos (str): (x y z) position to place the gripper in string form, e.g. '0 0 0.3'
+ quat (str): rotation to apply to gripper in string form, e.g. '0 0 1 0' to flip z axis
+ gripper_low_pos (float): controls the gipper y position, larger -> higher
+ gripper_high_pos (float): controls the gipper y high position larger -> higher,
+ must be larger than gripper_low_pos
+ box_size (None or 3-tuple of int): the size of the box to grasp, None defaults to [0.02, 0.02, 0.02]
+ box_density (int): the density of the box to grasp
+ step_time (int): the interval between two gripper actions
+ render (bool): if True, show rendering
+ """
+
+ def __init__(
+ self,
+ gripper,
+ pos,
+ quat,
+ gripper_low_pos,
+ gripper_high_pos,
+ box_size=None,
+ box_density=10000,
+ step_time=400,
+ render=True,
+ ):
+ # define viewer
+ self.viewer = None
+
+ world = MujocoWorldBase()
+ # Add a table
+ arena = TableArena(table_full_size=(0.4, 0.4, 0.1), table_offset=(0, 0, 0.1), has_legs=False)
+ world.merge(arena)
+
+ # Add a gripper
+ self.gripper = gripper
+ # Create another body with a slider joint to which we'll add this gripper
+ gripper_body = ET.Element("body")
+ gripper_body.set("pos", pos)
+ gripper_body.set("quat", quat) # flip z
+ gripper_body.append(new_joint(name="gripper_z_joint", type="slide", axis="0 0 -1", damping="50"))
+ # Add all gripper bodies to this higher level body
+ for body in gripper.worldbody:
+ gripper_body.append(body)
+ # Merge the all of the gripper tags except its bodies
+ world.merge(gripper, merge_body=None)
+ # Manually add the higher level body we created
+ world.worldbody.append(gripper_body)
+ # Create a new actuator to control our slider joint
+ world.actuator.append(new_actuator(joint="gripper_z_joint", act_type="position", name="gripper_z", kp="500"))
+
+ # Add an object for grasping
+ # density is in units kg / m3
+ TABLE_TOP = [0, 0, 0.09]
+ if box_size is None:
+ box_size = [0.02, 0.02, 0.02]
+ box_size = np.array(box_size)
+ self.cube = BoxObject(
+ name="object", size=box_size, rgba=[1, 0, 0, 1], friction=[1, 0.005, 0.0001], density=box_density
+ )
+ object_pos = np.array(TABLE_TOP + box_size * [0, 0, 1])
+ mujoco_object = self.cube.get_obj()
+ # Set the position of this object
+ mujoco_object.set("pos", array_to_string(object_pos))
+ # Add our object to the world body
+ world.worldbody.append(mujoco_object)
+
+ # add reference objects for x and y axes
+ x_ref = BoxObject(
+ name="x_ref", size=[0.01, 0.01, 0.01], rgba=[0, 1, 0, 1], obj_type="visual", joints=None
+ ).get_obj()
+ x_ref.set("pos", "0.2 0 0.105")
+ world.worldbody.append(x_ref)
+ y_ref = BoxObject(
+ name="y_ref", size=[0.01, 0.01, 0.01], rgba=[0, 0, 1, 1], obj_type="visual", joints=None
+ ).get_obj()
+ y_ref.set("pos", "0 0.2 0.105")
+ world.worldbody.append(y_ref)
+
+ self.world = world
+ self.render = render
+ self.simulation_ready = False
+ self.step_time = step_time
+ self.cur_step = 0
+ if gripper_low_pos > gripper_high_pos:
+ raise ValueError(
+ "gripper_low_pos {} is larger " "than gripper_high_pos {}".format(gripper_low_pos, gripper_high_pos)
+ )
+ self.gripper_low_pos = gripper_low_pos
+ self.gripper_high_pos = gripper_high_pos
+
+ def start_simulation(self):
+ """
+ Starts simulation of the test world
+ """
+ model_xml = self.world.get_xml()
+ self.sim = MjSim.from_xml_string(model_xml)
+
+ if self.render:
+ self.viewer = OpenCVRenderer(self.sim)
+ # We also need to add the offscreen context
+ if self.sim._render_context_offscreen is None:
+ render_context = MjRenderContextOffscreen(self.sim, device_id=-1)
+ self.sim.add_render_context(render_context)
+ self.sim_state = self.sim.get_state()
+
+ # For gravity correction
+ gravity_corrected = ["gripper_z_joint"]
+ self._gravity_corrected_qvels = [self.sim.model.get_joint_qvel_addr(x) for x in gravity_corrected]
+
+ self.gripper_z_id = self.sim.model.actuator_name2id("gripper_z")
+ self.gripper_z_is_low = False
+
+ self.gripper_actuator_ids = [self.sim.model.actuator_name2id(x) for x in self.gripper.actuators]
+
+ self.gripper_is_closed = True
+
+ self.object_id = self.sim.model.body_name2id(self.cube.root_body)
+ object_default_pos = self.sim.data.body_xpos[self.object_id]
+ self.object_default_pos = np.array(object_default_pos, copy=True)
+
+ self.reset()
+ self.simulation_ready = True
+
+ def reset(self):
+ """
+ Resets the simulation to the initial state
+ """
+ self.sim.set_state(self.sim_state)
+ self.cur_step = 0
+
+ def close(self):
+ """
+ Close the viewer if it exists
+ """
+ if self.viewer is not None:
+ self.viewer.close()
+
+ def step(self):
+ """
+ Forward the simulation by one timestep
+
+ Raises:
+ RuntimeError: if start_simulation is not yet called.
+ """
+ if not self.simulation_ready:
+ raise RuntimeError("Call start_simulation before calling step")
+ if self.gripper_z_is_low:
+ self.sim.data.ctrl[self.gripper_z_id] = self.gripper_low_pos
+ else:
+ self.sim.data.ctrl[self.gripper_z_id] = self.gripper_high_pos
+ if self.gripper_is_closed:
+ self._apply_gripper_action(1)
+ else:
+ self._apply_gripper_action(-1)
+ self._apply_gravity_compensation()
+ self.sim.step()
+ if self.render:
+ self.viewer.render()
+ self.cur_step += 1
+
+ def _apply_gripper_action(self, action):
+ """
+ Applies binary gripper action
+
+ Args:
+ action (int): Action to apply. Should be -1 (open) or 1 (closed)
+ """
+ gripper_action_actual = self.gripper.format_action(np.array([action]))
+ # rescale normalized gripper action to control ranges
+ ctrl_range = self.sim.model.actuator_ctrlrange[self.gripper_actuator_ids]
+ bias = 0.5 * (ctrl_range[:, 1] + ctrl_range[:, 0])
+ weight = 0.5 * (ctrl_range[:, 1] - ctrl_range[:, 0])
+ applied_gripper_action = bias + weight * gripper_action_actual
+ self.sim.data.ctrl[self.gripper_actuator_ids] = applied_gripper_action
+
+ def _apply_gravity_compensation(self):
+ """
+ Applies gravity compensation to the simulation
+ """
+ self.sim.data.qfrc_applied[self._gravity_corrected_qvels] = self.sim.data.qfrc_bias[
+ self._gravity_corrected_qvels
+ ]
+
+ def loop(self, total_iters=1, test_y=False, y_baseline=0.01):
+ """
+ Performs lower, grip, raise and release actions of a gripper,
+ each separated with T timesteps
+
+ Args:
+ total_iters (int): Iterations to perform before exiting
+ test_y (bool): test if object is lifted
+ y_baseline (float): threshold for determining that object is lifted
+ """
+ seq = [(False, False), (True, False), (True, True), (False, True)]
+ for cur_iter in range(total_iters):
+ for cur_plan in seq:
+ self.gripper_z_is_low, self.gripper_is_closed = cur_plan
+ for step in range(self.step_time):
+ self.step()
+ if test_y:
+ if not self.object_height > y_baseline:
+ raise ValueError(
+ "object is lifed by {}, ".format(self.object_height)
+ + "not reaching the requirement {}".format(y_baseline)
+ )
+
+ @property
+ def object_height(self):
+ """
+ Queries the height (z) of the object compared to on the ground
+
+ Returns:
+ float: Object height relative to default (ground) object position
+ """
+ return self.sim.data.body_xpos[self.object_id][2] - self.object_default_pos[2]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/jaco_three_finger_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/jaco_three_finger_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..188412d71880d3cec60606fdedc3dd30707c9162
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/jaco_three_finger_gripper.py
@@ -0,0 +1,107 @@
+"""
+Gripper for Kinova's Jaco robot arm (has three fingers).
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class JacoThreeFingerGripperBase(GripperModel):
+ """
+ Gripper for Kinova's Jaco robot arm (has three fingers).
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/jaco_three_finger_gripper.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.array([0.5, 0, 0.5, 0, 0.5, 0])
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": [
+ "index_proximal_collision",
+ "index_distal_collision",
+ "index_tip_collision",
+ "pinky_proximal_collision",
+ "pinky_distal_collision",
+ "pinky_tip_collision",
+ "index_tip_collision",
+ "pinky_pad_collision",
+ ],
+ "right_finger": [
+ "thumb_proximal_collision",
+ "thumb_distal_collision",
+ "thumb_tip_collision",
+ "thumb_pad_collision",
+ ],
+ "left_fingerpad": ["index_pad_collision", "pinky_pad_collision"],
+ "right_fingerpad": ["thumb_pad_collision"],
+ }
+
+
+class JacoThreeFingerGripper(JacoThreeFingerGripperBase):
+ """
+ Modifies JacoThreeFingerGripperBase to only take one action.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == self.dof
+ self.current_action = np.clip(self.current_action - self.speed * np.sign(action), -1.0, 1.0)
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.005
+
+ @property
+ def dof(self):
+ return 1
+
+
+class JacoThreeFingerDexterousGripper(JacoThreeFingerGripperBase):
+ """
+ Dexterous variation of the Jaco gripper in which all finger are actuated independently
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ all -1 => open, all 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == self.dof
+ self.current_action = np.clip(self.current_action - self.speed * np.sign(action), -1.0, 1.0)
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.005
+
+ @property
+ def dof(self):
+ return 3
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/null_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/null_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..48f6a804744099f5cfca1dd23397d2bd7f5cd5ca
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/null_gripper.py
@@ -0,0 +1,24 @@
+"""
+Null Gripper (if we don't want to attach gripper to robot eef).
+"""
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class NullGripper(GripperModel):
+ """
+ Dummy Gripper class to represent no gripper
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/null_gripper.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return None
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/panda_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/panda_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..64650e33d40f6df02d48032c91df3c71213d1751
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/panda_gripper.py
@@ -0,0 +1,66 @@
+"""
+Gripper for Franka's Panda (has two fingers).
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class PandaGripperBase(GripperModel):
+ """
+ Gripper for Franka's Panda (has two fingers).
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/panda_gripper.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.array([0.020833, -0.020833])
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": ["finger1_collision", "finger1_pad_collision"],
+ "right_finger": ["finger2_collision", "finger2_pad_collision"],
+ "left_fingerpad": ["finger1_pad_collision"],
+ "right_fingerpad": ["finger2_pad_collision"],
+ }
+
+
+class PandaGripper(PandaGripperBase):
+ """
+ Modifies PandaGripperBase to only take one action.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == self.dof
+ self.current_action = np.clip(
+ self.current_action + np.array([-1.0, 1.0]) * self.speed * np.sign(action), -1.0, 1.0
+ )
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 1
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/rethink_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/rethink_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa3852fcf1122e29682587b25643f6fdbc22cbf1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/rethink_gripper.py
@@ -0,0 +1,66 @@
+"""
+Gripper with two fingers for Rethink Robots.
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class RethinkGripperBase(GripperModel):
+ """
+ Gripper with long two-fingered parallel jaw.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/rethink_gripper.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.array([0.020833, -0.020833])
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": ["l_finger_g0", "l_finger_g1", "l_fingertip_g0", "l_fingerpad_g0"],
+ "right_finger": ["r_finger_g0", "r_finger_g1", "r_fingertip_g0", "r_fingerpad_g0"],
+ "left_fingerpad": ["l_fingerpad_g0"],
+ "right_fingerpad": ["r_fingerpad_g0"],
+ }
+
+
+class RethinkGripper(RethinkGripperBase):
+ """
+ Modifies two finger base to only take one action.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == 1
+ self.current_action = np.clip(
+ self.current_action + np.array([1.0, -1.0]) * self.speed * np.sign(action), -1.0, 1.0
+ )
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 1
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_140_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_140_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..38a2877e1f6f6366b3ae42f66a4c083dedb4ad09
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_140_gripper.py
@@ -0,0 +1,77 @@
+"""
+Gripper with 140mm Jaw width from Robotiq (has two fingers).
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Robotiq140GripperBase(GripperModel):
+ """
+ Gripper with 140mm Jaw width from Robotiq (has two fingers).
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/robotiq_gripper_140.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.array([0.012, 0.065, 0.065, -0.012, 0.065, 0.065])
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": [
+ "left_outer_finger_collision",
+ "left_inner_finger_collision",
+ "left_fingertip_collision",
+ "left_fingerpad_collision",
+ ],
+ "right_finger": [
+ "right_outer_finger_collision",
+ "right_inner_finger_collision",
+ "right_fingertip_collision",
+ "right_fingerpad_collision",
+ ],
+ "left_fingerpad": ["left_fingerpad_collision"],
+ "right_fingerpad": ["right_fingerpad_collision"],
+ }
+
+
+class Robotiq140Gripper(Robotiq140GripperBase):
+ """
+ Modifies Robotiq140GripperBase to only take one action.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == 1
+ self.current_action = np.clip(
+ self.current_action + np.array([1.0, -1.0]) * self.speed * np.sign(action), -1.0, 1.0
+ )
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 1
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_85_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_85_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..690017c2c3247cf987277583fa1e7d5dd63904ee
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_85_gripper.py
@@ -0,0 +1,74 @@
+"""
+6-DoF gripper with its open/close variant
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Robotiq85GripperBase(GripperModel):
+ """
+ 6-DoF Robotiq gripper.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/robotiq_gripper_85_v4.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.array([-0.026, -0.267, -0.200, -0.026, -0.267, -0.200])
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": [
+ "left_outer_finger_collision",
+ "left_inner_finger_collision",
+ "left_fingertip_collision",
+ "left_fingerpad_collision",
+ ],
+ "right_finger": [
+ "right_outer_finger_collision",
+ "right_inner_finger_collision",
+ "right_fingertip_collision",
+ "right_fingerpad_collision",
+ ],
+ "left_fingerpad": ["left_fingerpad_collision"],
+ "right_fingerpad": ["right_fingerpad_collision"],
+ }
+
+
+class Robotiq85Gripper(Robotiq85GripperBase):
+ """
+ 1-DoF variant of RobotiqGripperBase.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == 1
+ self.current_action = np.clip(self.current_action + self.speed * np.sign(action), -1.0, 1.0)
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 1
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_gripper_85_real_kinova.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_gripper_85_real_kinova.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ddfb4e988b60b177ebd9223b176d97bc1ecac17
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_gripper_85_real_kinova.py
@@ -0,0 +1,78 @@
+"""
+6-DoF gripper with its open/close variant
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Robotiq85GripperRealKinovaBase(GripperModel):
+ """
+ 6-DoF Robotiq gripper.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/robotiq_gripper_85_real_kinova.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.array([-0.026, -0.267, -0.200, -0.026, -0.267, -0.200])
+ # return np.array([0.00227, 0.000136, 0.00247, -0.00267, 0.00227, 0.000136, 0.00247, -0.00267])
+ # return np.array([0.00258958, 0.00264364, 0.0027039, 0.00258958, 0.00264361, 0.00270381])
+
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": [
+ "left_outer_finger_collision",
+ "left_inner_finger_collision",
+ "left_fingertip_collision",
+ "left_fingerpad_collision",
+ ],
+ "right_finger": [
+ "right_outer_finger_collision",
+ "right_inner_finger_collision",
+ "right_fingertip_collision",
+ "right_fingerpad_collision",
+ ],
+ "left_fingerpad": ["left_fingerpad_collision"],
+ "right_fingerpad": ["right_fingerpad_collision"],
+ }
+
+
+class Robotiq85GripperRealKinova(Robotiq85GripperRealKinovaBase):
+ """
+ 1-DoF variant of RobotiqGripperBase.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == 1
+ self.current_action = np.clip(self.current_action + self.speed * np.sign(action), -1.0, 1.0)
+ print("Modified gripper action: ", self.current_action)
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 1
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_three_finger_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_three_finger_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d36d56e84dbe7636b5f4e9899dec3123615c2a2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/robotiq_three_finger_gripper.py
@@ -0,0 +1,115 @@
+"""
+Gripper with 11-DoF controlling three fingers and its open/close variant.
+"""
+import numpy as np
+
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class RobotiqThreeFingerGripperBase(GripperModel):
+ """
+ Gripper with 11 dof controlling three fingers.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/robotiq_gripper_s.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return np.zeros(11)
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": [
+ "f1_l0",
+ "f1_l1",
+ "f1_l2",
+ "f1_l3",
+ "f2_l0",
+ "f2_l1",
+ "f2_l2",
+ "f2_l3",
+ "f1_tip_collision",
+ "f2_tip_collision",
+ "f1_pad_collision",
+ "f2_pad_collision",
+ ],
+ "right_finger": [
+ "f3_l0",
+ "f3_l1",
+ "f3_l2",
+ "f3_l3",
+ "finger_middle_tip_collision",
+ "finger_middle_pad_collision",
+ ],
+ "left_fingerpad": ["f1_pad_collision", "f2_pad_collision"],
+ "right_fingerpad": ["finger_middle_pad_collision"],
+ }
+
+
+class RobotiqThreeFingerGripper(RobotiqThreeFingerGripperBase):
+ """
+ 1-DoF variant of RobotiqThreeFingerGripperBase.
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ -1 => open, 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == self.dof
+ self.current_action = np.clip(self.current_action + self.speed * np.array(action), -1.0, 1.0)
+ # Automatically set the scissor joint to "closed" position by default
+ return np.concatenate([self.current_action * np.ones(3), [-1]])
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 1
+
+
+class RobotiqThreeFingerDexterousGripper(RobotiqThreeFingerGripperBase):
+ """
+ Dexterous variation of the 3-finger Robotiq gripper in which all finger are actuated independently as well
+ as the scissor joint between fingers 1 and 2
+ """
+
+ def format_action(self, action):
+ """
+ Maps continuous action into binary output
+ all -1 => open, all 1 => closed
+
+ Args:
+ action (np.array): gripper-specific action
+
+ Raises:
+ AssertionError: [Invalid action dimension size]
+ """
+ assert len(action) == self.dof
+ self.current_action = np.clip(self.current_action + self.speed * np.sign(action), -1.0, 1.0)
+ return self.current_action
+
+ @property
+ def speed(self):
+ return 0.01
+
+ @property
+ def dof(self):
+ return 4
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/grippers/wiping_gripper.py b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/wiping_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..692475efaba30837edd8ef987801124f6dfb672b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/grippers/wiping_gripper.py
@@ -0,0 +1,34 @@
+"""
+Gripper without fingers to wipe a surface
+"""
+from robosuite.models.grippers.gripper_model import GripperModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class WipingGripper(GripperModel):
+ """
+ A Wiping Gripper with no actuation and enabled with sensors to detect contact forces
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("grippers/wiping_gripper.xml"), idn=idn)
+
+ def format_action(self, action):
+ return action
+
+ @property
+ def init_qpos(self):
+ return None
+
+ @property
+ def _important_geoms(self):
+ return {
+ "left_finger": [],
+ "right_finger": [],
+ "left_fingerpad": [],
+ "right_fingerpad": [],
+ "corners": ["wiping_corner1", "wiping_corner2", "wiping_corner3", "wiping_corner4"],
+ }
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/mounts/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f92d43a51d49865a410a7899b9c41aa99085bd09
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/__init__.py
@@ -0,0 +1,15 @@
+from .mount_model import MountModel
+from .mount_factory import mount_factory
+
+from .rethink_mount import RethinkMount
+from .phantom_mount import PhantomMount
+from .null_mount import NullMount
+
+
+MOUNT_MAPPING = {
+ "RethinkMount": RethinkMount,
+ "PhantomMount": PhantomMount,
+ None: NullMount,
+}
+
+ALL_MOUNTS = MOUNT_MAPPING.keys()
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/mounts/mount_factory.py b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/mount_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..09b30335494171d2916068fa03551be36582caec
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/mount_factory.py
@@ -0,0 +1,25 @@
+"""
+Defines a string based method of initializing mounts
+"""
+
+
+def mount_factory(name, idn=0):
+ """
+ Generator for grippers
+
+ Creates a MountModel instance with the provided name.
+
+ Args:
+ name (None or str): the name of the mount class
+ idn (int or str): Number or some other unique identification string for this mount instance
+
+ Returns:
+ MountModel: requested mount instance
+
+ Raises:
+ XMLError: [invalid XML]
+ """
+ # Import MOUNT_MAPPING at runtime so we avoid circular imports
+ from robosuite.models.mounts import MOUNT_MAPPING
+
+ return MOUNT_MAPPING.get(name, "Unknown mount name: {}".format(name))(idn=idn)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/mounts/mount_model.py b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/mount_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeb9be9a35122079e4d836f3cb875dc959790e4b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/mount_model.py
@@ -0,0 +1,92 @@
+"""
+Defines the base class of all mounts
+"""
+import numpy as np
+
+from robosuite.models.base import MujocoXMLModel
+from robosuite.utils.mjcf_utils import MOUNT_COLLISION_COLOR
+
+
+class MountModel(MujocoXMLModel):
+ """
+ Base class for mounts that will be attached to robots. Note that this model's root body will be directly
+ appended to the robot's root body, so all offsets should be taken relative to that.
+
+ Args:
+ fname (str): Path to relevant xml file to create this mount instance
+ idn (int or str): Number or some other unique identification string for this gripper instance
+ """
+
+ def __init__(self, fname, idn=0):
+ super().__init__(fname, idn=idn)
+
+ # Grab mount offset (string -> np.array -> elements [1, 2, 3, 0] (x, y, z, w))
+ self.rotation_offset = np.fromstring(
+ self.worldbody[0].attrib.get("quat", "1 0 0 0"), dtype=np.float64, sep=" "
+ )[[1, 2, 3, 0]]
+
+ # -------------------------------------------------------------------------------------- #
+ # Properties: In general, these are the name-adjusted versions from the private #
+ # subclass implementations pulled from their respective raw xml files #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def naming_prefix(self):
+ return "mount{}_".format(self.idn)
+
+ @property
+ def _important_sites(self):
+ """
+ Returns:
+ dict: (Default is no important sites; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def _important_geoms(self):
+ """
+ Returns:
+ dict: (Default is no important geoms; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def _important_sensors(self):
+ """
+ Returns:
+ dict: (Default is no sensors; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def contact_geom_rgba(self):
+ return MOUNT_COLLISION_COLOR
+
+ # -------------------------------------------------------------------------------------- #
+ # All subclasses must implement the following properties #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def top_offset(self):
+ """
+ Returns vector from model root body to model top.
+ This should correspond to the distance from the root body to the actual mounting surface
+ location of this mount.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ raise NotImplementedError
+
+ @property
+ def horizontal_radius(self):
+ """
+ Returns maximum distance from model root body to any radial point of the model.
+
+ Helps us put models programmatically without them flying away due to a huge initial contact force.
+ Must be defined by subclass.
+
+ Returns:
+ float: radius
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/mounts/null_mount.py b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/null_mount.py
new file mode 100644
index 0000000000000000000000000000000000000000..3848e4ca2d37377780220253117ea74bea1ca769
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/null_mount.py
@@ -0,0 +1,27 @@
+"""
+Rethink's Generic Mount (Officially used on Sawyer).
+"""
+import numpy as np
+
+from robosuite.models.mounts.mount_model import MountModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class NullMount(MountModel):
+ """
+ Dummy Mount to signify no mount.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this mount instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("mounts/null_mount.xml"), idn=idn)
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 0))
+
+ @property
+ def horizontal_radius(self):
+ return 0
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/mounts/phantom_mount.py b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/phantom_mount.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be462ea19b7e8b6d644628775c651d4d03b38cb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/phantom_mount.py
@@ -0,0 +1,28 @@
+"""
+Phantom mount.
+"""
+import numpy as np
+
+from robosuite.models.mounts.mount_model import MountModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class PhantomMount(MountModel):
+ """
+ Mount officially used for Rethink's Baxter Robot. Includes only a wheeled pedestal.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this mount instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("mounts/phantom_mount.xml"), idn=idn)
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, -0.062))
+
+ @property
+ def horizontal_radius(self):
+ # TODO: This may be inaccurate; just a placeholder for now
+ return 0.25
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/mounts/rethink_mount.py b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/rethink_mount.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2903c91521a86268b6f8f109b45794a4dea30b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/mounts/rethink_mount.py
@@ -0,0 +1,28 @@
+"""
+Rethink's Generic Mount (Officially used on Sawyer).
+"""
+import numpy as np
+
+from robosuite.models.mounts.mount_model import MountModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class RethinkMount(MountModel):
+ """
+ Mount officially used for Rethink's Sawyer Robot. Includes a controller box and wheeled pedestal.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this mount instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("mounts/rethink_mount.xml"), idn=idn)
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, -0.01))
+
+ @property
+ def horizontal_radius(self):
+ # TODO: This may be inaccurate; just a placeholder for now
+ return 0.25
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e11ddca0521eb6d0a4bdba1e9d1582cef2fafa7c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/__init__.py
@@ -0,0 +1,24 @@
+from .objects import MujocoObject, MujocoXMLObject, MujocoGeneratedObject
+from .generated_objects import CompositeBodyObject, CompositeObject, PrimitiveObject
+from .object_groups import ObjectGroup
+
+from .xml_objects import (
+ BottleObject,
+ CanObject,
+ LemonObject,
+ MilkObject,
+ BreadObject,
+ CerealObject,
+ SquareNutObject,
+ RoundNutObject,
+ MilkVisualObject,
+ BreadVisualObject,
+ CerealVisualObject,
+ CanVisualObject,
+ PlateWithHoleObject,
+ DoorObject,
+)
+from .primitive import *
+from .composite import *
+from .composite_body import *
+from .group import *
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b6a445552775cb953873d0f3dddcea00dcaf2d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/__init__.py
@@ -0,0 +1,8 @@
+from .bin import Bin
+from .hammer import HammerObject
+from .lid import Lid
+from .pot_with_handles import PotWithHandlesObject
+from .hollow_cylinder import HollowCylinderObject
+from .cone import ConeObject
+from .hook_frame import HookFrame
+from .stand_with_mount import StandWithMount
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/bin.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/bin.py
new file mode 100644
index 0000000000000000000000000000000000000000..a69afc06c4892894a3fec2b01625f8cf5137a870
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/bin.py
@@ -0,0 +1,146 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import CustomMaterial, add_to_dict
+
+
+class Bin(CompositeObject):
+ """
+ Generates a four-walled bin container with an open top.
+ Args:
+ name (str): Name of this Bin object
+ bin_size (3-array): (x,y,z) full size of bin
+ wall_thickness (float): How thick to make walls of bin
+ transparent_walls (bool): If True, walls will be semi-translucent
+ friction (3-array or None): If specified, sets friction values for this bin. None results in default values
+ density (float): Density value to use for all geoms. Defaults to 1000
+ use_texture (bool): If true, geoms will be defined by realistic textures and rgba values will be ignored
+ rgba (4-array or None): If specified, sets rgba values for all geoms. None results in default values
+ """
+
+ def __init__(
+ self,
+ name,
+ bin_size=(0.3, 0.3, 0.15),
+ wall_thickness=0.01,
+ transparent_walls=True,
+ friction=None,
+ density=1000.0,
+ use_texture=True,
+ rgba=(0.2, 0.1, 0.0, 1.0),
+ ):
+ # Set name
+ self._name = name
+
+ # Set object attributes
+ self.bin_size = np.array(bin_size)
+ self.wall_thickness = wall_thickness
+ self.transparent_walls = transparent_walls
+ self.friction = friction if friction is None else np.array(friction)
+ self.density = density
+ self.use_texture = use_texture
+ self.rgba = rgba
+ self.bin_mat_name = "dark_wood_mat"
+
+ # Element references
+ self._base_geom = "base"
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ bin_mat = CustomMaterial(
+ texture="WoodDark",
+ tex_name="dark_wood",
+ mat_name=self.bin_mat_name,
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(bin_mat)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": self.bin_size / 2.0,
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ "density": self.density,
+ }
+ obj_args = {}
+
+ # Base
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, -(self.bin_size[2] - self.wall_thickness) / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=(
+ np.array((self.bin_size[0], self.bin_size[1], self.wall_thickness))
+ - np.array((self.wall_thickness, self.wall_thickness, 0))
+ )
+ / 2,
+ geom_names=self._base_geom,
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.bin_mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # Walls
+ x_vals = np.array(
+ [0, -(self.bin_size[0] - self.wall_thickness) / 2, 0, (self.bin_size[0] - self.wall_thickness) / 2]
+ )
+ y_vals = np.array(
+ [-(self.bin_size[1] - self.wall_thickness) / 2, 0, (self.bin_size[1] - self.wall_thickness) / 2, 0]
+ )
+ w_vals = np.array([self.bin_size[0], self.bin_size[1], self.bin_size[0], self.bin_size[1]])
+ r_vals = np.array([np.pi / 2, 0, -np.pi / 2, np.pi])
+ if self.transparent_walls:
+ wall_rgba = (1.0, 1.0, 1.0, 0.3)
+ wall_mat = None
+ else:
+ wall_rgba = None if self.use_texture else self.rgba
+ wall_mat = self.bin_mat_name if self.use_texture else None
+ for i, (x, y, w, r) in enumerate(zip(x_vals, y_vals, w_vals, r_vals)):
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(x, y, 0),
+ geom_quats=T.convert_quat(T.axisangle2quat(np.array([0, 0, r])), to="wxyz"),
+ geom_sizes=(self.wall_thickness / 2, w / 2, self.bin_size[2] / 2),
+ geom_names=f"wall{i}",
+ geom_rgbas=wall_rgba,
+ geom_materials=wall_mat,
+ geom_frictions=self.friction,
+ )
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
+
+ @property
+ def base_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to bin base
+ """
+ return [self.correct_naming(self._base_geom)]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/cone.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/cone.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b35e3fe36205e4162b27addf5d9da2d50629354
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/cone.py
@@ -0,0 +1,156 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import RED, CustomMaterial, add_to_dict
+
+
+class ConeObject(CompositeObject):
+ """
+ Generates an approximate cone object by using cylinder or box geoms.
+ Args:
+ name (str): Name of this Cone object
+ outer_radius (float): Radius of cone base
+ inner_radius (float): Radius of cone tip (since everything is a cylinder or box)
+ height (float): Height of cone
+ ngeoms (int): Number of cylinder or box geoms used to approximate the cone. Use
+ more geoms to make the approximation better.
+ use_box (bool): If true, use box geoms instead of cylinders, corresponding to a
+ square pyramid shape instead of a conical shape.
+ """
+
+ def __init__(
+ self,
+ name,
+ outer_radius=0.0425,
+ inner_radius=0.03,
+ height=0.05,
+ ngeoms=8,
+ use_box=False,
+ rgba=None,
+ material=None,
+ density=1000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.9, 0.95, 0.001),
+ friction=None,
+ ):
+
+ # Set object attributes
+ self._name = name
+ self.rgba = rgba
+ self.density = density
+ self.friction = friction if friction is None else np.array(friction)
+ self.solref = solref
+ self.solimp = solimp
+
+ self.has_material = material is not None
+ if self.has_material:
+ assert isinstance(material, CustomMaterial)
+ self.material = material
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # radius of the tip and the base
+ self.r1 = inner_radius
+ self.r2 = outer_radius
+
+ # number of geoms used to approximate the cone
+ if ngeoms % 2 == 0:
+ # use an odd number of geoms for easier computation
+ ngeoms += 1
+ self.n = ngeoms
+
+ # cone height
+ self.height = height
+
+ # unit half-height for geoms
+ self.unit_height = (height / ngeoms) / 2.0
+
+ # unit radius for geom radius grid
+ self.unit_r = (self.r2 - self.r1) / (self.n - 1)
+
+ self.use_box = use_box
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Optionally add material
+ if self.has_material:
+ self.append_material(self.material)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": [self.r2, self.r2, self.height / 2.0],
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ "density": self.density,
+ "solref": self.solref,
+ "solimp": self.solimp,
+ }
+ obj_args = {}
+
+ # stack the boxes / cylinders in the z-direction
+ ngeoms_each_side = (self.n - 1) // 2
+ geom_locations = [
+ (0.0, 0.0, i * self.unit_height * 2.0) for i in range(-ngeoms_each_side, ngeoms_each_side + 1)
+ ]
+
+ if self.use_box:
+ geom_sizes = [
+ (
+ self.r1 + i * self.unit_r,
+ self.r1 + i * self.unit_r,
+ self.unit_height,
+ )
+ for i in range(self.n)
+ ][::-1]
+ else:
+ geom_sizes = [
+ (
+ self.r1 + i * self.unit_r,
+ self.unit_height,
+ )
+ for i in range(self.n)
+ ][::-1]
+
+ for i in range(self.n):
+ # note: set geom condim to 4 for consistency with round-nut.xml
+ # geom_quat = np.array([np.cos(geom_angle / 2.), 0., 0., np.sin(geom_angle / 2.)])
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box" if self.use_box else "cylinder",
+ geom_locations=geom_locations[i],
+ geom_quats=None,
+ geom_sizes=geom_sizes[i],
+ geom_names="c_{}".format(i),
+ # geom_rgbas=None if self.has_material else self.rgba,
+ geom_rgbas=self.rgba,
+ geom_materials=self.material.mat_attrib["name"] if self.has_material else None,
+ geom_frictions=self.friction,
+ geom_condims=4,
+ )
+
+ # Sites
+ obj_args["sites"] = [
+ {
+ "name": "center",
+ "pos": (0, 0, 0),
+ "size": "0.002",
+ "rgba": RED,
+ "type": "sphere",
+ }
+ ]
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hammer.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hammer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcfa96c9d038f2197201dba5b3094526e4e69ff1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hammer.py
@@ -0,0 +1,282 @@
+from collections.abc import Iterable
+
+import numpy as np
+
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import BLUE, CYAN, GREEN, RED, CustomMaterial, add_to_dict
+
+
+class HammerObject(CompositeObject):
+ """
+ Generates a Hammer object with a cylindrical or box-shaped handle, cubic head, cylindrical face and triangular claw
+ (used in Handover task)
+
+ Args:
+ name (str): Name of this Hammer object
+
+ handle_shape (str): Either "box", for a box-shaped handle, or "cylinder", for a cylindrically-shaped handle
+
+ handle_radius (float or 2-array of float): Either specific or range of values to draw randomly from
+ uniformly for the handle radius
+
+ handle_length (float or 2-array of float): Either specific or range of values to draw randomly from
+ uniformly for the handle length
+
+ handle_density (float or 2-array of float): Either specific or range of values to draw randomly from
+ uniformly for the handle density (in SI units). Note that this value is scaled x4 for the hammer head
+
+ handle_friction (float or 2-array of float): Either specific or range of values to draw randomly from
+ uniformly for the handle friction. Note that Mujoco default values are used for the head
+
+ head_density_ratio (float): Ratio of density of handle to head (including face and claw)
+
+ use_texture (bool): If true, geoms will be defined by realistic textures and rgba values will be ignored
+
+ rgba_handle (4-array or None): If specified, sets handle rgba values
+
+ rgba_head (4-array or None): If specified, sets handle rgba values
+
+ rgba_face (4-array or None): If specified, sets handle rgba values
+
+ rgba_claw (4-array or None): If specified, sets handle rgba values
+
+ Raises:
+ ValueError: [Invalid handle shape]
+ """
+
+ def __init__(
+ self,
+ name,
+ handle_shape="box",
+ handle_radius=(0.015, 0.02),
+ handle_length=(0.1, 0.25),
+ handle_density=(100, 250),
+ handle_friction=(3.0, 5.0),
+ head_density_ratio=2.0,
+ use_texture=True,
+ rgba_handle=None,
+ rgba_head=None,
+ rgba_face=None,
+ rgba_claw=None,
+ ):
+ # Set name
+ self._name = name
+
+ # Set handle type and density ratio
+ self.handle_shape = handle_shape
+ self.head_density_ratio = head_density_ratio
+
+ # Set radius and length ranges
+ self.handle_radius_range = handle_radius if isinstance(handle_radius, Iterable) else [handle_radius] * 2
+ self.handle_length_range = handle_length if isinstance(handle_length, Iterable) else [handle_length] * 2
+ self.handle_density_range = handle_density if isinstance(handle_density, Iterable) else [handle_density] * 2
+ self.handle_friction_range = handle_friction if isinstance(handle_friction, Iterable) else [handle_friction] * 2
+
+ # Sample actual radius and length, as well as head half-size
+ self.handle_radius = np.random.uniform(self.handle_radius_range[0], self.handle_radius_range[1])
+ self.handle_length = np.random.uniform(self.handle_length_range[0], self.handle_length_range[1])
+ self.handle_density = np.random.uniform(self.handle_density_range[0], self.handle_density_range[1])
+ self.handle_friction = np.random.uniform(self.handle_friction_range[0], self.handle_friction_range[1])
+ self.head_halfsize = np.random.uniform(self.handle_radius, self.handle_radius * 1.2)
+
+ # Initialize RGBA values and texture flag
+ self.use_texture = use_texture
+ self.rgba_handle = rgba_handle if rgba_handle is not None else RED
+ self.rgba_head = rgba_head if rgba_head is not None else CYAN
+ self.rgba_face = rgba_face if rgba_face is not None else BLUE
+ self.rgba_claw = rgba_claw if rgba_claw is not None else GREEN
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ metal = CustomMaterial(
+ texture="SteelScratched",
+ tex_name="metal",
+ mat_name="metal_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ wood = CustomMaterial(
+ texture="WoodLight",
+ tex_name="wood",
+ mat_name="wood_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+
+ # Append materials to object
+ self.append_material(metal)
+ self.append_material(wood)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ full_size = np.array(
+ (3.2 * self.head_halfsize, self.head_halfsize, self.handle_length + 2 * self.head_halfsize)
+ )
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": full_size / 2.0,
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ }
+ obj_args = {}
+
+ # Add handle component
+ assert self.handle_shape in {
+ "cylinder",
+ "box",
+ }, "Error loading hammer: Handle type must either be 'box' or 'cylinder', got {}.".format(self.handle_shape)
+ add_to_dict(
+ dic=obj_args,
+ geom_types="cylinder" if self.handle_shape == "cylinder" else "box",
+ geom_locations=(0, 0, 0),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array([self.handle_radius, self.handle_length / 2.0])
+ if self.handle_shape == "cylinder"
+ else np.array([self.handle_radius, self.handle_radius, self.handle_length / 2.0]),
+ geom_names="handle",
+ geom_rgbas=None if self.use_texture else self.rgba_handle,
+ geom_materials="wood_mat" if self.use_texture else None,
+ geom_frictions=(self.handle_friction, 0.005, 0.0001),
+ density=self.handle_density,
+ )
+
+ # Add head component
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, self.handle_length / 2.0 + self.head_halfsize),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array([self.head_halfsize * 2, self.head_halfsize, self.head_halfsize]),
+ geom_names="head",
+ geom_rgbas=None if self.use_texture else self.rgba_head,
+ geom_materials="metal_mat" if self.use_texture else None,
+ geom_frictions=None,
+ density=self.handle_density * self.head_density_ratio,
+ )
+
+ # Add neck component
+ add_to_dict(
+ dic=obj_args,
+ geom_types="cylinder",
+ geom_locations=(self.head_halfsize * 2.2, 0, self.handle_length / 2.0 + self.head_halfsize),
+ geom_quats=(0.707106, 0, 0.707106, 0),
+ geom_sizes=np.array([self.head_halfsize * 0.8, self.head_halfsize * 0.2]),
+ geom_names="neck",
+ geom_rgbas=None if self.use_texture else self.rgba_face,
+ geom_materials="metal_mat" if self.use_texture else None,
+ geom_frictions=None,
+ density=self.handle_density * self.head_density_ratio,
+ )
+
+ # Add face component
+ add_to_dict(
+ dic=obj_args,
+ geom_types="cylinder",
+ geom_locations=(self.head_halfsize * 2.8, 0, self.handle_length / 2.0 + self.head_halfsize),
+ geom_quats=(0.707106, 0, 0.707106, 0),
+ geom_sizes=np.array([self.head_halfsize, self.head_halfsize * 0.4]),
+ geom_names="face",
+ geom_rgbas=None if self.use_texture else self.rgba_face,
+ geom_materials="metal_mat" if self.use_texture else None,
+ geom_frictions=None,
+ density=self.handle_density * self.head_density_ratio,
+ )
+
+ # Add claw component
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(-self.head_halfsize * 2, 0, self.handle_length / 2.0 + self.head_halfsize),
+ geom_quats=(0.9238795, 0, 0.3826834, 0),
+ geom_sizes=np.array([self.head_halfsize * 0.7072, self.head_halfsize * 0.95, self.head_halfsize * 0.7072]),
+ geom_names="claw",
+ geom_rgbas=None if self.use_texture else self.rgba_claw,
+ geom_materials="metal_mat" if self.use_texture else None,
+ geom_frictions=None,
+ density=self.handle_density * self.head_density_ratio,
+ )
+
+ # Add back in base args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
+
+ @property
+ def init_quat(self):
+ """
+ Generates a new random orientation for the hammer
+
+ Returns:
+ np.array: (x, y, z, w) quaternion orientation for the hammer
+ """
+ # Randomly sample between +/- flip (such that the hammer head faces one way or the other)
+ return np.array([0.5, -0.5, 0.5, -0.5]) if np.random.rand() >= 0.5 else np.array([-0.5, -0.5, -0.5, -0.5])
+
+ @property
+ def handle_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to hammer handle
+ """
+ return self.correct_naming(["handle"])
+
+ @property
+ def head_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to hammer head
+ """
+ return self.correct_naming(["head"])
+
+ @property
+ def face_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to hammer face
+ """
+ return self.correct_naming(["neck", "face"])
+
+ @property
+ def claw_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to hammer claw
+ """
+ return self.correct_naming(["claw"])
+
+ @property
+ def all_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to all hammer components
+ """
+ return self.handle_geoms + self.head_geoms + self.face_geoms + self.claw_geoms
+
+ @property
+ def bottom_offset(self):
+ return np.array([0, 0, -self.handle_radius])
+
+ @property
+ def top_offset(self):
+ return np.array([0, 0, self.handle_radius])
+
+ @property
+ def horizontal_radius(self):
+ return self.head_halfsize + 0.5 * self.handle_length
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hollow_cylinder.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hollow_cylinder.py
new file mode 100644
index 0000000000000000000000000000000000000000..329dba0b045bfd99390fa269d34625785f845934
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hollow_cylinder.py
@@ -0,0 +1,146 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import RED, CustomMaterial, add_to_dict
+
+
+class HollowCylinderObject(CompositeObject):
+ """
+ Generates an approximate hollow cylinder object by using box geoms.
+ Args:
+ name (str): Name of this HollowCylinder object
+ outer_radius (float): Outer radius of hollow cylinder
+ inner_radius (float): Inner radius of hollow cylinder
+ height (float): Height of hollow cylinder
+ ngeoms (int): Number of box geoms used to approximate the cylindrical shell. Use
+ more geoms to make the approximation better.
+ make_half (bool): If true, only make half of the shell.
+ """
+
+ def __init__(
+ self,
+ name,
+ outer_radius=0.0425,
+ inner_radius=0.03,
+ height=0.05,
+ ngeoms=8,
+ rgba=None,
+ material=None,
+ density=1000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.9, 0.95, 0.001),
+ friction=None,
+ make_half=False,
+ ):
+
+ # Set object attributes
+ self._name = name
+ self.rgba = rgba
+ self.density = density
+ self.friction = friction if friction is None else np.array(friction)
+ self.solref = solref
+ self.solimp = solimp
+ self.make_half = make_half # if True, will only make half the hollow cylinder
+
+ self.has_material = material is not None
+ if self.has_material:
+ assert isinstance(material, CustomMaterial)
+ self.material = material
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # radius of the inner cup hole and entire cup
+ self.r1 = inner_radius
+ self.r2 = outer_radius
+
+ # number of geoms used to approximate the cylindrical shell
+ self.n = ngeoms
+
+ # cylinder half-height
+ self.height = height
+
+ # half-width of each box inferred from triangle of radius + box half-length
+ # since the angle will be (360 / n) / 2
+ self.unit_box_width = self.r2 * np.sin(np.pi / self.n)
+
+ # half-height of each box inferred from the same triangle with inner radius
+ self.unit_box_height = (self.r2 - self.r1) * np.cos(np.pi / self.n) / 2.0
+
+ # each box geom depth will end up defining the height of the cup
+ self.unit_box_depth = self.height
+
+ # radius of intermediate circle that connects all box centers
+ self.int_r = (self.r1 * np.cos(np.pi / self.n)) + self.unit_box_height
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Optionally add material
+ if self.has_material:
+ self.append_material(self.material)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": [self.r2, self.r2, self.height],
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ "density": self.density,
+ "solref": self.solref,
+ "solimp": self.solimp,
+ }
+ obj_args = {}
+
+ n_make = self.n
+ if self.make_half:
+ # only make half the shell
+ n_make = (self.n // 2) + 1
+
+ # infer locations of all geoms with trigonometry
+ angle_step = 2.0 * np.pi / self.n
+ for i in range(n_make):
+ # we start with the top-most box object and proceed clockwise (thus an offset of np.pi)
+ geom_angle = np.pi - i * angle_step
+ geom_center = np.array([self.int_r * np.cos(geom_angle), self.int_r * np.sin(geom_angle), 0.0])
+ geom_quat = np.array([np.cos(geom_angle / 2.0), 0.0, 0.0, np.sin(geom_angle / 2.0)])
+ geom_size = np.array([self.unit_box_height, self.unit_box_width, self.unit_box_depth])
+
+ # note: set geom condim to 4 for consistency with round-nut.xml
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=tuple(geom_center),
+ geom_quats=tuple(geom_quat),
+ geom_sizes=tuple(geom_size),
+ geom_names="hc_{}".format(i),
+ # geom_rgbas=None if self.has_material else self.rgba,
+ geom_rgbas=self.rgba,
+ geom_materials=self.material.mat_attrib["name"] if self.has_material else None,
+ geom_frictions=self.friction,
+ geom_condims=4,
+ )
+
+ # Sites
+ obj_args["sites"] = [
+ {
+ "name": "center",
+ "pos": (0, 0, 0),
+ "size": "0.002",
+ "rgba": RED,
+ "type": "sphere",
+ }
+ ]
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hook_frame.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hook_frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..3741c7b74dac78a1963da1c2c01d7023fc9c85d5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/hook_frame.py
@@ -0,0 +1,332 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import BLUE, GREEN, RED, CustomMaterial, add_to_dict
+
+
+class HookFrame(CompositeObject):
+ """
+ Generates an upside down L-shaped frame (a "hook" shape), intended to be used with StandWithMount object.
+ Args:
+ name (str): Name of this object
+ frame_length (float): How long the frame is
+ frame_height (float): How tall the frame is
+ frame_thickness (float): How thick the frame is
+ hook_height (float): if not None, add a box geom at the edge of the hook with this height (not half-height)
+ grip_location (float): if not None, adds a grip to passed location, relative to center of the rod corresponding to @frame_height.
+ grip_size ([float]): (R, H) radius and half-height for the cylindrical grip. Set to None
+ to not add a grip.
+ tip_size ([float]): if not None, adds a cone tip to the end of the hook for easier insertion, with the
+ provided (CH, LR, UR, H) where CH is the base cylinder height, LR and UR are the lower and upper radius
+ of the cone tip, and H is the half-height of the cone tip
+ friction (3-array or None): If specified, sets friction values for this object. None results in default values
+ density (float): Density value to use for all geoms. Defaults to 1000
+ use_texture (bool): If true, geoms will be defined by realistic textures and rgba values will be ignored
+ rgba (4-array or None): If specified, sets rgba values for all geoms. None results in default values
+ """
+
+ def __init__(
+ self,
+ name,
+ frame_length=0.3,
+ frame_height=0.2,
+ frame_thickness=0.025,
+ hook_height=None,
+ grip_location=None,
+ grip_size=None,
+ tip_size=None,
+ friction=None,
+ density=1000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.9, 0.95, 0.001),
+ use_texture=True,
+ rgba=(0.2, 0.1, 0.0, 1.0),
+ ):
+ # Set name
+ self._name = name
+
+ # Set object attributes
+ self.size = None # Filled in automatically
+ self.frame_length = frame_length
+ self.frame_height = frame_height
+ self.frame_thickness = frame_thickness
+ self.hook_height = hook_height
+ self.grip_location = grip_location
+ self.grip_size = tuple(grip_size) if grip_size is not None else None
+ self.tip_size = tuple(tip_size) if tip_size is not None else None
+ self.friction = friction if friction is None else np.array(friction)
+ self.solref = solref
+ self.solimp = solimp
+ self.density = density
+ self.use_texture = use_texture
+ self.rgba = rgba
+ self.mat_name = "brass_mat"
+ self.grip_mat_name = "ceramic_mat"
+ self.tip_mat_name = "steel_mat"
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ bin_mat = CustomMaterial(
+ texture="Brass",
+ tex_name="brass",
+ mat_name=self.mat_name,
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(bin_mat)
+ # optionally add material for grip
+ if (self.grip_location is not None) and (self.grip_size is not None):
+ grip_mat = CustomMaterial(
+ texture="Ceramic",
+ tex_name="ceramic",
+ mat_name=self.grip_mat_name,
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(grip_mat)
+ # optionally add material for tip
+ if self.tip_size is not None:
+ tip_mat = CustomMaterial(
+ texture="SteelScratched",
+ tex_name="steel",
+ mat_name=self.tip_mat_name,
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(tip_mat)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ self.size = np.array((self.frame_length, self.frame_thickness, self.frame_height))
+ if self.tip_size is not None:
+ self.size[2] += 2.0 * (self.tip_size[0] + (2.0 * self.tip_size[3]))
+ base_args = {
+ "total_size": self.size / 2,
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ "density": self.density,
+ "solref": self.solref,
+ "solimp": self.solimp,
+ }
+ obj_args = {}
+
+ # Vertical Frame
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=((self.frame_length - self.frame_thickness) / 2, 0, -self.frame_thickness / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array((self.frame_thickness, self.frame_thickness, self.frame_height - self.frame_thickness))
+ / 2,
+ geom_names="vertical_frame",
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # Horizontal Frame
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, (self.frame_height - self.frame_thickness) / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array((self.frame_length, self.frame_thickness, self.frame_thickness)) / 2,
+ geom_names="horizontal_frame",
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # optionally add hook at the end of the horizontal frame
+ if self.hook_height is not None:
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(
+ (-self.frame_length + self.frame_thickness) / 2,
+ 0,
+ (self.frame_height + self.hook_height) / 2,
+ ),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array((self.frame_thickness, self.frame_thickness, self.hook_height)) / 2,
+ geom_names="hook_frame",
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # optionally add a grip
+ if (self.grip_location is not None) and (self.grip_size is not None):
+ # note: use box grip instead of cylindrical grip for stability
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(
+ (self.frame_length - self.frame_thickness) / 2,
+ 0,
+ (-self.frame_thickness / 2) + self.grip_location,
+ ),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=(self.grip_size[0], self.grip_size[0], self.grip_size[1]),
+ geom_names="grip_frame",
+ # geom_rgbas=None if self.use_texture else self.rgba,
+ geom_rgbas=(0.13, 0.13, 0.13, 1.0),
+ geom_materials=self.grip_mat_name if self.use_texture else None,
+ # geom_frictions=self.friction,
+ geom_frictions=(1.0, 0.005, 0.0001), # use default friction
+ )
+
+ # optionally add cone tip
+ if self.tip_size is not None:
+ from robosuite.models.objects import ConeObject
+
+ cone = ConeObject(
+ name="cone",
+ outer_radius=self.tip_size[2],
+ inner_radius=self.tip_size[1],
+ height=self.tip_size[3],
+ # ngeoms=8,
+ ngeoms=50,
+ use_box=True,
+ # use_box=False,
+ rgba=None,
+ material=None,
+ density=self.density,
+ solref=self.solref,
+ solimp=self.solimp,
+ friction=self.friction,
+ )
+ cone_args = cone._get_geom_attrs()
+
+ # DIRTY HACK: add them in reverse (in hindsight, should just turn this into a composite body...)
+ cone_geom_types = cone_args["geom_types"]
+ cone_geom_locations = cone_args["geom_locations"]
+ cone_geom_sizes = cone_args["geom_sizes"][::-1]
+
+ # location of mount site is the translation we need
+ cylinder_offset = (
+ (self.frame_length - self.frame_thickness) / 2,
+ 0,
+ -self.frame_height / 2 - self.tip_size[0], # account for half-height of cylinder
+ )
+ cone_offset = (
+ cylinder_offset[0],
+ cylinder_offset[1],
+ cylinder_offset[2]
+ - self.tip_size[0]
+ - self.tip_size[3] / 2.0, # need to move below cylinder, and account for half-height
+ )
+
+ # first add cylinder
+ add_to_dict(
+ dic=obj_args,
+ geom_types="cylinder",
+ geom_locations=cylinder_offset,
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=(self.tip_size[2], self.tip_size[0]),
+ geom_names="tip_cylinder",
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.tip_mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # then add cone tip geoms
+ for i in range(len(cone_geom_types)):
+ add_to_dict(
+ dic=obj_args,
+ geom_types=cone_geom_types[i],
+ geom_locations=(
+ cone_geom_locations[i][0] + cone_offset[0],
+ cone_geom_locations[i][1] + cone_offset[1],
+ cone_geom_locations[i][2] + cone_offset[2],
+ ),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=cone_geom_sizes[i],
+ geom_names="tip_cone_{}".format(i),
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.tip_mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # Sites
+ obj_args["sites"] = [
+ {
+ "name": f"hang_site",
+ "pos": (-self.frame_length / 2, 0, (self.frame_height - self.frame_thickness) / 2),
+ "size": "0.002",
+ "rgba": RED,
+ "type": "sphere",
+ },
+ {
+ "name": f"mount_site",
+ "pos": ((self.frame_length - self.frame_thickness) / 2, 0, -self.frame_height / 2),
+ "size": "0.002",
+ "rgba": GREEN,
+ "type": "sphere",
+ },
+ {
+ "name": f"intersection_site",
+ "pos": (
+ (self.frame_length - self.frame_thickness) / 2,
+ 0,
+ (self.frame_height - self.frame_thickness) / 2,
+ ),
+ "size": "0.002",
+ "rgba": BLUE,
+ "type": "sphere",
+ },
+ ]
+
+ if self.tip_size is not None:
+ obj_args["sites"].append(
+ {
+ "name": f"tip_site",
+ "pos": (
+ ((self.frame_length - self.frame_thickness) / 2),
+ 0,
+ (-self.frame_height / 2) - 2.0 * self.tip_size[0] - self.tip_size[3],
+ ),
+ "size": "0.002",
+ "rgba": RED,
+ "type": "sphere",
+ },
+ )
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
+
+ @property
+ def init_quat(self):
+ """
+ Rotate the frame on its side so it is flat
+ Returns:
+ np.array: (x, y, z, w) quaternion orientation for this object
+ """
+ # Rotate 90 degrees about two consecutive axes to make the hook lie on the table instead of being upright.
+ return T.quat_multiply(
+ np.array([0, 0.0, np.sqrt(2) / 2.0, np.sqrt(2) / 2.0]),
+ np.array([-np.sqrt(2) / 2.0, 0.0, 0.0, np.sqrt(2) / 2.0]),
+ )
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/lid.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/lid.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb602bd9308bbe387e4648939d61c00c32f5520
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/lid.py
@@ -0,0 +1,136 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import CustomMaterial, add_to_dict
+
+
+class Lid(CompositeObject):
+ """
+ Generates a square lid with a simple handle.
+ Args:
+ name (str): Name of this Lid object
+ lid_size (3-array): (length, width, thickness) of lid
+ handle_size (3-array): (thickness, length, height) of handle
+ transparent (bool): If True, lid will be semi-translucent
+ friction (3-array or None): If specified, sets friction values for this lid. None results in default values
+ density (float): Density value to use for all geoms. Defaults to 1000
+ use_texture (bool): If true, geoms will be defined by realistic textures and rgba values will be ignored
+ rgba (4-array or None): If specified, sets rgba values for all geoms. None results in default values
+ """
+
+ def __init__(
+ self,
+ name,
+ lid_size=(0.3, 0.3, 0.01),
+ handle_size=(0.02, 0.08, 0.03),
+ transparent=True,
+ friction=None,
+ density=250.0,
+ use_texture=True,
+ rgba=(0.2, 0.1, 0.0, 1.0),
+ ):
+ # Set name
+ self._name = name
+
+ # Set object attributes
+ self.lid_size = np.array(lid_size)
+ self.handle_size = np.array(handle_size)
+ self.transparent = transparent
+ self.friction = friction if friction is None else np.array(friction)
+ self.density = density
+ self.use_texture = use_texture
+ self.rgba = rgba
+ self.lid_mat_name = "dark_wood_mat"
+
+ # Element references
+ self._handle_geom = "handle"
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ lid_mat = CustomMaterial(
+ texture="WoodDark",
+ tex_name="dark_wood",
+ mat_name=self.lid_mat_name,
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(lid_mat)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ full_height = self.lid_size[2] + self.handle_size[2]
+ full_size = np.array([self.lid_size[0], self.lid_size[1], full_height])
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": full_size / 2.0,
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ }
+ obj_args = {}
+
+ # Top
+ if self.transparent:
+ top_rgba = (1.0, 1.0, 1.0, 0.3)
+ top_mat = None
+ else:
+ top_rgba = None if self.use_texture else self.rgba
+ top_mat = self.lid_mat_name if self.use_texture else None
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, (-full_size[2] + self.lid_size[2]) / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array((full_size[0], full_size[1], self.lid_size[2])) / 2,
+ geom_names="top",
+ geom_rgbas=top_rgba,
+ geom_materials=top_mat,
+ geom_frictions=self.friction,
+ density=self.density,
+ )
+
+ # Handle
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, (full_size[2] - self.handle_size[2]) / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=self.handle_size / 2,
+ geom_names=self._handle_geom,
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.lid_mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ density=self.density * 2,
+ )
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
+
+ @property
+ def handle_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to lid handle
+ """
+ return [self.correct_naming(self._handle_geom)]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/pot_with_handles.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/pot_with_handles.py
new file mode 100644
index 0000000000000000000000000000000000000000..6783e5ebe9fc74a7c08c2f44536e6ecdeb73a9d7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/pot_with_handles.py
@@ -0,0 +1,350 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import BLUE, GREEN, RED, CustomMaterial, add_to_dict, array_to_string
+
+
+class PotWithHandlesObject(CompositeObject):
+ """
+ Generates the Pot object with side handles (used in TwoArmLift)
+
+ Args:
+ name (str): Name of this Pot object
+
+ body_half_size (3-array of float): If specified, defines the (x,y,z) half-dimensions of the main pot
+ body. Otherwise, defaults to [0.07, 0.07, 0.07]
+
+ handle_radius (float): Determines the pot handle radius
+
+ handle_length (float): Determines the pot handle length
+
+ handle_width (float): Determines the pot handle width
+
+ handle_friction (float): Friction value to use for pot handles. Defauls to 1.0
+
+ density (float): Density value to use for all geoms. Defaults to 1000
+
+ use_texture (bool): If true, geoms will be defined by realistic textures and rgba values will be ignored
+
+ rgba_body (4-array or None): If specified, sets pot body rgba values
+
+ rgba_handle_0 (4-array or None): If specified, sets handle 0 rgba values
+
+ rgba_handle_1 (4-array or None): If specified, sets handle 1 rgba values
+
+ solid_handle (bool): If true, uses a single geom to represent the handle
+
+ thickness (float): How thick to make the pot body walls
+ """
+
+ def __init__(
+ self,
+ name,
+ body_half_size=(0.07, 0.07, 0.07),
+ handle_radius=0.01,
+ handle_length=0.09,
+ handle_width=0.09,
+ handle_friction=1.0,
+ density=1000,
+ use_texture=True,
+ rgba_body=None,
+ rgba_handle_0=None,
+ rgba_handle_1=None,
+ solid_handle=False,
+ thickness=0.01, # For body
+ ):
+ # Set name
+ self._name = name
+
+ # Set object attributes
+ self.body_half_size = np.array(body_half_size)
+ self.thickness = thickness
+ self.handle_radius = handle_radius
+ self.handle_length = handle_length
+ self.handle_width = handle_width
+ self.handle_friction = handle_friction
+ self.density = density
+ self.use_texture = use_texture
+ self.rgba_body = np.array(rgba_body) if rgba_body else RED
+ self.rgba_handle_0 = np.array(rgba_handle_0) if rgba_handle_0 else GREEN
+ self.rgba_handle_1 = np.array(rgba_handle_1) if rgba_handle_1 else BLUE
+ self.solid_handle = solid_handle
+
+ # Element references to be filled when generated
+ self._handle0_geoms = None
+ self._handle1_geoms = None
+ self.pot_base = None
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "1 1",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ redwood = CustomMaterial(
+ texture="WoodRed",
+ tex_name="redwood",
+ mat_name="pot_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ greenwood = CustomMaterial(
+ texture="WoodGreen",
+ tex_name="greenwood",
+ mat_name="handle0_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ bluewood = CustomMaterial(
+ texture="WoodBlue",
+ tex_name="bluewood",
+ mat_name="handle1_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(redwood)
+ self.append_material(greenwood)
+ self.append_material(bluewood)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ full_size = np.array(
+ (
+ self.body_half_size,
+ self.body_half_size + self.handle_length * 2,
+ self.body_half_size,
+ )
+ )
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": full_size / 2.0,
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ }
+ site_attrs = []
+ obj_args = {}
+
+ # Initialize geom lists
+ self._handle0_geoms = []
+ self._handle1_geoms = []
+
+ # Add main pot body
+ # Base geom
+ name = f"base"
+ self.pot_base = [name]
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, -self.body_half_size[2] + self.thickness / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array([self.body_half_size[0], self.body_half_size[1], self.thickness / 2]),
+ geom_names=name,
+ geom_rgbas=None if self.use_texture else self.rgba_body,
+ geom_materials="pot_mat" if self.use_texture else None,
+ geom_frictions=None,
+ density=self.density,
+ )
+
+ # Walls
+ x_off = np.array(
+ [0, -(self.body_half_size[0] - self.thickness / 2), 0, self.body_half_size[0] - self.thickness / 2]
+ )
+ y_off = np.array(
+ [-(self.body_half_size[1] - self.thickness / 2), 0, self.body_half_size[1] - self.thickness / 2, 0]
+ )
+ w_vals = np.array(
+ [self.body_half_size[1], self.body_half_size[0], self.body_half_size[1], self.body_half_size[0]]
+ )
+ r_vals = np.array([np.pi / 2, 0, -np.pi / 2, np.pi])
+ for i, (x, y, w, r) in enumerate(zip(x_off, y_off, w_vals, r_vals)):
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(x, y, 0),
+ geom_quats=T.convert_quat(T.axisangle2quat(np.array([0, 0, r])), to="wxyz"),
+ geom_sizes=np.array([self.thickness / 2, w, self.body_half_size[2]]),
+ geom_names=f"body{i}",
+ geom_rgbas=None if self.use_texture else self.rgba_body,
+ geom_materials="pot_mat" if self.use_texture else None,
+ geom_frictions=None,
+ density=self.density,
+ )
+
+ # Add handles
+ main_bar_size = np.array(
+ [
+ self.handle_width / 2 + self.handle_radius,
+ self.handle_radius,
+ self.handle_radius,
+ ]
+ )
+ side_bar_size = np.array([self.handle_radius, self.handle_length / 2, self.handle_radius])
+ handle_z = self.body_half_size[2] - self.handle_radius
+ for i, (g_list, handle_side, rgba) in enumerate(
+ zip([self._handle0_geoms, self._handle1_geoms], [1.0, -1.0], [self.rgba_handle_0, self.rgba_handle_1])
+ ):
+ handle_center = np.array((0, handle_side * (self.body_half_size[1] + self.handle_length), handle_z))
+ # Solid handle case
+ if self.solid_handle:
+ name = f"handle{i}"
+ g_list.append(name)
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=handle_center,
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array([self.handle_width / 2, self.handle_length / 2, self.handle_radius]),
+ geom_names=name,
+ geom_rgbas=None if self.use_texture else rgba,
+ geom_materials=f"handle{i}_mat" if self.use_texture else None,
+ geom_frictions=(self.handle_friction, 0.005, 0.0001),
+ density=self.density,
+ )
+ # Hollow handle case
+ else:
+ # Center bar
+ name = f"handle{i}_c"
+ g_list.append(name)
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=handle_center,
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=main_bar_size,
+ geom_names=name,
+ geom_rgbas=None if self.use_texture else rgba,
+ geom_materials=f"handle{i}_mat" if self.use_texture else None,
+ geom_frictions=(self.handle_friction, 0.005, 0.0001),
+ density=self.density,
+ )
+ # Side bars
+ for bar_side, suffix in zip([-1.0, 1.0], ["-", "+"]):
+ name = f"handle{i}_{suffix}"
+ g_list.append(name)
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(
+ bar_side * self.handle_width / 2,
+ handle_side * (self.body_half_size[1] + self.handle_length / 2),
+ handle_z,
+ ),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=side_bar_size,
+ geom_names=name,
+ geom_rgbas=None if self.use_texture else rgba,
+ geom_materials=f"handle{i}_mat" if self.use_texture else None,
+ geom_frictions=(self.handle_friction, 0.005, 0.0001),
+ density=self.density,
+ )
+ # Add relevant site
+ handle_site = self.get_site_attrib_template()
+ handle_name = f"handle{i}"
+ handle_site.update(
+ {
+ "name": handle_name,
+ "pos": array_to_string(handle_center - handle_side * np.array([0, 0.005, 0])),
+ "size": "0.005",
+ "rgba": rgba,
+ }
+ )
+ site_attrs.append(handle_site)
+ # Add to important sites
+ self._important_sites[f"handle{i}"] = self.naming_prefix + handle_name
+
+ # Add pot body site
+ pot_site = self.get_site_attrib_template()
+ center_name = "center"
+ pot_site.update(
+ {
+ "name": center_name,
+ "size": "0.005",
+ }
+ )
+ site_attrs.append(pot_site)
+ # Add to important sites
+ self._important_sites["center"] = self.naming_prefix + center_name
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+ obj_args["sites"] = site_attrs # All sites are part of main (top) body
+
+ # Return this dict
+ return obj_args
+
+ @property
+ def handle_distance(self):
+
+ """
+ Calculates how far apart the handles are
+
+ Returns:
+ float: handle distance
+ """
+ return self.body_half_size[1] * 2 + self.handle_length * 2
+
+ @property
+ def handle0_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to handle0 (green handle)
+ """
+ return self.correct_naming(self._handle0_geoms)
+
+ @property
+ def handle1_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to handle1 (blue handle)
+ """
+ return self.correct_naming(self._handle1_geoms)
+
+ @property
+ def handle_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to both handles
+ """
+ return self.handle0_geoms + self.handle1_geoms
+
+ @property
+ def important_sites(self):
+ """
+ Returns:
+ dict: In addition to any default sites for this object, also provides the following entries
+
+ :`'handle0'`: Name of handle0 location site
+ :`'handle1'`: Name of handle1 location site
+ """
+ # Get dict from super call and add to it
+ dic = super().important_sites
+ dic.update(self._important_sites)
+ return dic
+
+ @property
+ def bottom_offset(self):
+ return np.array([0, 0, -1 * self.body_half_size[2]])
+
+ @property
+ def top_offset(self):
+ return np.array([0, 0, self.body_half_size[2]])
+
+ @property
+ def horizontal_radius(self):
+ return np.sqrt(2) * (max(self.body_half_size) + self.handle_length)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/stand_with_mount.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/stand_with_mount.py
new file mode 100644
index 0000000000000000000000000000000000000000..903c35a19a2443a58a51b9ce05f8af0b6ea8f8c8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite/stand_with_mount.py
@@ -0,0 +1,199 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import CompositeObject
+from robosuite.utils.mjcf_utils import RED, CustomMaterial, add_to_dict
+
+
+class StandWithMount(CompositeObject):
+ """
+ Generates a flat stand with a four-walled mount sticking out of the top.
+ Args:
+ name (str): Name of this object
+ size (3-array): (x,y,z) full size of object
+ mount_location (2-array): (x,y) location to place mount, relative to center of stand
+ mount_width (float): How wide mount is (measured from outside of walls!)
+ wall_thickness (float): How thick to make walls of mount
+ initialize_on_side (bool): If True, will initialize this stand on its side (tipped over)
+ add_hole_vis (bool): If True, adds a rim around the top of the walls, to help make the hole more visually distinctive
+ friction (3-array or None): If specified, sets friction values for this object. None results in default values
+ density (float): Density value to use for all geoms. Defaults to 1000
+ use_texture (bool): If true, geoms will be defined by realistic textures and rgba values will be ignored
+ rgba (4-array or None): If specified, sets rgba values for all geoms. None results in default values
+ """
+
+ def __init__(
+ self,
+ name,
+ size=(0.3, 0.3, 0.15),
+ mount_location=(0.0, 0.0),
+ mount_width=0.05,
+ wall_thickness=0.01,
+ base_thickness=0.01,
+ initialize_on_side=True,
+ add_hole_vis=False,
+ friction=None,
+ density=1000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.9, 0.95, 0.001),
+ use_texture=True,
+ rgba=(0.2, 0.1, 0.0, 1.0),
+ ):
+ # Set name
+ self._name = name
+
+ # Set object attributes
+ self.size = np.array(size)
+ self.mount_location = np.array(mount_location)
+ self.mount_width = mount_width
+ self.wall_thickness = wall_thickness
+ self.base_thickness = base_thickness
+ self.initialize_on_side = initialize_on_side
+ self.add_hole_vis = add_hole_vis
+ self.friction = friction if friction is None else np.array(friction)
+ self.solref = solref
+ self.solimp = solimp
+ self.density = density
+ self.use_texture = use_texture
+ self.rgba = rgba
+ self.mat_name = "brass_mat"
+
+ # Element references
+ self._base_geom = "base"
+
+ # Other private attributes
+ self._important_sites = {}
+
+ # Create dictionary of values to create geoms for composite object and run super init
+ super().__init__(**self._get_geom_attrs())
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ bin_mat = CustomMaterial(
+ texture="Brass",
+ tex_name="brass",
+ mat_name=self.mat_name,
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.append_material(bin_mat)
+
+ def _get_geom_attrs(self):
+ """
+ Creates geom elements that will be passed to superclass CompositeObject constructor
+ Returns:
+ dict: args to be used by CompositeObject to generate geoms
+ """
+ # Initialize dict of obj args that we'll pass to the CompositeObject constructor
+ base_args = {
+ "total_size": self.size / 2.0,
+ "name": self.name,
+ "locations_relative_to_center": True,
+ "obj_types": "all",
+ "density": self.density,
+ "solref": self.solref,
+ "solimp": self.solimp,
+ }
+ obj_args = {}
+
+ # Base
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(0, 0, -(self.size[2] - self.base_thickness) / 2),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=np.array((self.size[0], self.size[1], self.base_thickness)) / 2,
+ geom_names=self._base_geom,
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ # Walls
+ x_vals = (
+ np.array(
+ [0, -(self.mount_width - self.wall_thickness) / 2, 0, (self.mount_width - self.wall_thickness) / 2]
+ )
+ + self.mount_location[0]
+ )
+ y_vals = (
+ np.array(
+ [-(self.mount_width - self.wall_thickness) / 2, 0, (self.mount_width - self.wall_thickness) / 2, 0]
+ )
+ + self.mount_location[1]
+ )
+ r_vals = np.array([np.pi / 2, 0, -np.pi / 2, np.pi])
+ for i, (x, y, r) in enumerate(zip(x_vals, y_vals, r_vals)):
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(x, y, self.base_thickness / 2),
+ geom_quats=T.convert_quat(T.axisangle2quat(np.array([0, 0, r])), to="wxyz"),
+ geom_sizes=(self.wall_thickness / 2, self.mount_width / 2, (self.size[2] - self.base_thickness) / 2),
+ geom_names=f"wall{i}",
+ geom_rgbas=None if self.use_texture else self.rgba,
+ geom_materials=self.mat_name if self.use_texture else None,
+ geom_frictions=self.friction,
+ )
+
+ if self.add_hole_vis:
+ # add a purely visual rim
+ del base_args["obj_types"]
+ obj_args["obj_types"] = len(obj_args["geom_types"]) * ["all"]
+
+ vis_geom_side = 0.7 * ((self.mount_width - self.wall_thickness) / 2)
+ vis_geom_size = (vis_geom_side, vis_geom_side, self.wall_thickness / 2)
+ add_to_dict(
+ dic=obj_args,
+ geom_types="box",
+ geom_locations=(self.mount_location[0], self.mount_location[1], (self.size[2] / 2) - vis_geom_size[2]),
+ geom_quats=(1, 0, 0, 0),
+ geom_sizes=vis_geom_size,
+ geom_names="hole_vis",
+ geom_rgbas=(0.0, 1.0, 0.0, 0.5),
+ geom_materials=None,
+ geom_frictions=self.friction,
+ obj_types="visual",
+ )
+
+ # Sites
+ obj_args["sites"] = [
+ {
+ "name": f"mount_site",
+ "pos": (0, 0, self.size[2] / 2),
+ "size": "0.002",
+ "rgba": RED,
+ "type": "sphere",
+ }
+ ]
+
+ # Add back in base args and site args
+ obj_args.update(base_args)
+
+ # Return this dict
+ return obj_args
+
+ @property
+ def init_quat(self):
+ """
+ Optionally rotate the mount on its side so it is flat
+ Returns:
+ np.array: (x, y, z, w) quaternion orientation for this object
+ """
+ # Rotate 90 deg about Y axis if at all
+ return np.array([0, 0.707107, 0, 0.707107]) if self.initialize_on_side else np.array([0, 0, 0, 1])
+
+ @property
+ def base_geoms(self):
+ """
+ Returns:
+ list of str: geom names corresponding to base
+ """
+ return [self.correct_naming(self._base_geom)]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..79ccc9a8e74b24e77ac07a4045d2be68e8872167
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/__init__.py
@@ -0,0 +1,2 @@
+from .hinged_box import HingedBoxObject
+from .ratcheting_wrench import RatchetingWrenchObject
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/hinged_box.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/hinged_box.py
new file mode 100644
index 0000000000000000000000000000000000000000..12aa3cd0c62ab94df33033ea0bbcd88693876866
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/hinged_box.py
@@ -0,0 +1,141 @@
+import numpy as np
+
+from robosuite.models.objects import BoxObject, CompositeBodyObject, CylinderObject
+from robosuite.utils.mjcf_utils import BLUE, RED, CustomMaterial, array_to_string
+
+
+class HingedBoxObject(CompositeBodyObject):
+ """
+ An example object that demonstrates the CompositeBodyObject functionality. This object consists of two cube bodies
+ joined together by a hinge joint.
+
+ Args:
+ name (str): Name of this object
+
+ box1_size (3-array): (L, W, H) half-sizes for the first box
+
+ box2_size (3-array): (L, W, H) half-sizes for the second box
+
+ use_texture (bool): set True if using wood textures for the blocks
+ """
+
+ def __init__(
+ self,
+ name,
+ box1_size=(0.025, 0.025, 0.025),
+ box2_size=(0.025, 0.025, 0.0125),
+ use_texture=True,
+ ):
+ # Set box sizes
+ self.box1_size = np.array(box1_size)
+ self.box2_size = np.array(box2_size)
+
+ # Set texture attributes
+ self.use_texture = use_texture
+ self.box1_material = None
+ self.box2_material = None
+ self.box1_rgba = RED
+ self.box2_rgba = BLUE
+
+ # Define materials we want to use for this object
+ if self.use_texture:
+ # Remove RGBAs
+ self.box1_rgba = None
+ self.box2_rgba = None
+
+ # Set materials for each box
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ self.box1_material = CustomMaterial(
+ texture="WoodRed",
+ tex_name="box1_tex",
+ mat_name="box1_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+ self.box2_material = CustomMaterial(
+ texture="WoodBlue",
+ tex_name="box2_tex",
+ mat_name="box2_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+
+ # Create objects
+ objects = []
+ for i, (size, mat, rgba) in enumerate(
+ zip(
+ (self.box1_size, self.box2_size),
+ (self.box1_material, self.box2_material),
+ (self.box1_rgba, self.box2_rgba),
+ )
+ ):
+ objects.append(
+ BoxObject(
+ name=f"box{i + 1}",
+ size=size,
+ rgba=rgba,
+ material=mat,
+ )
+ )
+
+ # Also add hinge for visualization
+ objects.append(
+ CylinderObject(
+ name="hinge",
+ size=np.array(
+ [min(self.box1_size[2], self.box2_size[2]) / 5.0, min(self.box1_size[0], self.box2_size[0])]
+ ),
+ rgba=[0.5, 0.5, 0, 1],
+ obj_type="visual",
+ )
+ )
+
+ # Define hinge joint
+ rel_hinge_pos = [self.box2_size[0], 0, -self.box2_size[2]] # want offset in all except y-axis
+ hinge_joint = {
+ "name": "box_hinge",
+ "type": "hinge",
+ "axis": "0 1 0", # y-axis hinge
+ "pos": array_to_string(rel_hinge_pos),
+ "stiffness": "0.0001",
+ "limited": "true",
+ "range": "0 1.57",
+ }
+
+ # Define positions -- second box should lie on top of first box with edge aligned at hinge joint
+ # Hinge visualizer should be aligned at hinge joint location
+ positions = [
+ np.zeros(3), # First box is centered at top-level body anyways
+ np.array([-(self.box2_size[0] - self.box1_size[0]), 0, self.box1_size[2] + self.box2_size[2]]),
+ np.array(rel_hinge_pos),
+ ]
+
+ quats = [
+ None, # Default quaternion for box 1
+ None, # Default quaternion for box 2
+ [0.707, 0.707, 0, 0], # Rotated 90 deg about x-axis
+ ]
+
+ # Define parents -- which body each is aligned to
+ parents = [
+ None, # box 1 attached to top-level body
+ objects[0].root_body, # box 2 attached to box 1
+ objects[1].root_body, # hinge attached to box 2
+ ]
+
+ # Run super init
+ super().__init__(
+ name=name,
+ objects=objects,
+ object_locations=positions,
+ object_quats=quats,
+ object_parents=parents,
+ body_joints={objects[1].root_body: [hinge_joint]},
+ )
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/ratcheting_wrench.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/ratcheting_wrench.py
new file mode 100644
index 0000000000000000000000000000000000000000..6686541956d90f7fd06125332bbc342e6c7b5ac1
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/composite_body/ratcheting_wrench.py
@@ -0,0 +1,150 @@
+import numpy as np
+
+from robosuite.models.objects import BoxObject, CompositeBodyObject, CylinderObject, HollowCylinderObject
+from robosuite.utils.mjcf_utils import CustomMaterial
+
+
+class RatchetingWrenchObject(CompositeBodyObject):
+ """
+ A ratcheting wrench made out of mujoco primitives.
+ Args:
+ name (str): Name of this object
+ handle_size ([float]): (L, W, H) half-sizes for the handle (center part of wrench)
+ outer_radius_1 (float): Outer radius of first end of wrench
+ inner_radius_1 (float): Inner radius of first end of wrench
+ height_1 (float): Height of first end of wrench
+ outer_radius_2 (float): Outer radius of second end of wrench
+ inner_radius_2 (float): Inner radius of second end of wrench
+ height_2 (float): Height of second end of wrench
+ ngeoms (int): Number of box geoms used to approximate the ends of the wrench. Use
+ more geoms to make the approximation better.
+ grip_size ([float]): (R, H) radius and half-height for the box grip. Set to None
+ to not add a grip.
+ """
+
+ def __init__(
+ self,
+ name,
+ handle_size=(0.08, 0.01, 0.005),
+ outer_radius_1=0.0425,
+ inner_radius_1=0.03,
+ height_1=0.05,
+ outer_radius_2=0.0425,
+ inner_radius_2=0.03,
+ height_2=0.05,
+ ngeoms=8,
+ grip_size=None,
+ # rgba=None,
+ density=1000.0,
+ solref=(0.02, 1.0),
+ solimp=(0.9, 0.95, 0.001),
+ friction=None,
+ ):
+ # Object properties
+ self.handle_size = tuple(handle_size)
+ self.outer_radii = (outer_radius_1, outer_radius_2)
+ self.inner_radii = (inner_radius_1, inner_radius_2)
+ self.heights = (height_1, height_2)
+ self.ngeoms = ngeoms
+ self.grip_size = tuple(grip_size) if grip_size is not None else None
+
+ # Define materials we want to use for this object
+ tex_attrib = {
+ "type": "cube",
+ }
+ mat_attrib = {
+ "texrepeat": "3 3",
+ "specular": "0.4",
+ "shininess": "0.1",
+ }
+ wrench_mat = CustomMaterial(
+ texture="SteelScratched",
+ tex_name="steel",
+ mat_name="steel_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+
+ if self.grip_size is not None:
+ grip_mat = CustomMaterial(
+ texture="Ceramic",
+ tex_name="ceramic",
+ mat_name="ceramic_mat",
+ tex_attrib=tex_attrib,
+ mat_attrib=mat_attrib,
+ )
+
+ # Create objects
+ objects = []
+
+ # each end of the wrench is modeled by a hollow cylinder
+ for i in range(2):
+ objects.append(
+ HollowCylinderObject(
+ name=f"hole{i + 1}",
+ outer_radius=self.outer_radii[i],
+ inner_radius=self.inner_radii[i],
+ height=self.heights[i],
+ ngeoms=self.ngeoms,
+ rgba=None,
+ material=wrench_mat,
+ density=density,
+ solref=solref,
+ solimp=solimp,
+ friction=friction,
+ make_half=False,
+ )
+ )
+
+ # also add center box geom for handle
+ objects.append(
+ BoxObject(
+ name="handle",
+ size=handle_size,
+ rgba=None,
+ material=wrench_mat,
+ density=density,
+ solref=solref,
+ solimp=solimp,
+ friction=friction,
+ )
+ )
+
+ # Define positions (top-level body is centered at handle)
+ hole_1_box_geom_height = 2.0 * objects[0].unit_box_height
+ hole_2_box_geom_height = 2.0 * objects[1].unit_box_height
+ positions = [
+ # this computation ensures no gaps between the center bar geom and the two wrench holes at the end
+ np.array([-handle_size[0] - self.outer_radii[0] + hole_1_box_geom_height, 0, 0]),
+ np.array([handle_size[0] + self.outer_radii[1] - hole_2_box_geom_height, 0, 0]),
+ np.zeros(3),
+ ]
+ quats = [None, None, None]
+ parents = [None, None, None]
+
+ # maybe add grip
+ if self.grip_size is not None:
+ objects.append(
+ BoxObject(
+ name="grip",
+ size=[self.grip_size[0], self.grip_size[0], self.grip_size[1]],
+ rgba=(0.13, 0.13, 0.13, 1.0),
+ density=density,
+ solref=solref,
+ solimp=solimp,
+ friction=(1.0, 0.005, 0.0001), # use default friction
+ )
+ )
+ positions.append(np.zeros(3))
+ quats.append((np.sqrt(2) / 2.0, 0.0, np.sqrt(2) / 2.0, 0.0)) # rotate 90 degrees about y-axis
+ parents.append(None)
+
+ # Run super init
+ super().__init__(
+ name=name,
+ objects=objects,
+ object_locations=positions,
+ object_quats=quats,
+ object_parents=parents,
+ joints=[dict(type="free", damping="0.0005")], # be consistent with round-nut.xml
+ )
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/generated_objects.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/generated_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..770624b848099721f9d86ec8193e2cbd0c6e6ca2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/generated_objects.py
@@ -0,0 +1,785 @@
+from copy import deepcopy
+
+import numpy as np
+
+from robosuite.models.objects import MujocoGeneratedObject, MujocoObject
+from robosuite.utils.mjcf_utils import (
+ OBJECT_COLLISION_COLOR,
+ CustomMaterial,
+ add_prefix,
+ array_to_string,
+ find_elements,
+ new_body,
+ new_geom,
+ new_joint,
+ new_site,
+)
+
+
+class CompositeBodyObject(MujocoGeneratedObject):
+ """
+ An object constructed out of multiple bodies to make more complex shapes.
+
+ Args:
+ name (str): Name of overall object
+
+ objects (MujocoObject or list of MujocoObjects): object(s) to combine to form the composite body object.
+ Note that these objects will be added sequentially, so if an object is required to be nested relative to
+ another object, that nested object should be listed after the parent object. Note that all top-level joints
+ for any inputted objects are automatically stripped
+
+ object_locations (list): list of body locations in the composite. Each
+ location should be a list or tuple of 3 elements and all
+ locations are taken relative to that object's parent body. Giving None for a location results in (0,0,0)
+ for that object.
+
+ object_quats (None or list): list of (w, x, y, z) quaternions for each body. None results in (1,0,0,0) for
+ that object.
+
+ object_parents (None or list): Parent bodies to append each object to. Note that specifying "None" will
+ automatically append all objects to the root body ("root")
+
+ joints (None or list): Joints to use for the top-level composite body object. If None, no joints will be used
+ for this top-level object. If "default", a single free joint will be added to the top-level body of this
+ object. Otherwise, should be a list of dictionaries, where each dictionary should specify the specific
+ joint attributes necessary. See http://www.mujoco.org/book/XMLreference.html#joint for reference.
+
+ body_joints (None or dict): If specified, maps body names to joint specifications to append to that
+ body. If None, no extra joints will be used. If mapped value is "default", a single free joint will be
+ added to the specified body. Otherwise, should be a list of dictionaries, where each dictionary should
+ specify the specific joint attributes necessary. See http://www.mujoco.org/book/XMLreference.html#joint
+ for reference.
+
+ sites (None or list): list of sites to add to top-level composite body object. If None, only the default
+ top-level object site will be used. Otherwise, should be a list of dictionaries, where each dictionary
+ should specify the appropriate attributes for the given site.
+ See http://www.mujoco.org/book/XMLreference.html#site for reference.
+
+ total_size (None or np.array): if provided, use this to describe the bounding box for this composite body
+ object. Can also be used to specify @object_locations relative to the lower left corner of the bounding
+ box defined by @total_size, instead of the center of this body, with @locations_relative_to_corner.
+
+ locations_relative_to_corner (bool): if True, must supply @total_size. All object locations will be
+ relative to the lower left corner of the bounding box.
+ """
+
+ def __init__(
+ self,
+ name,
+ objects,
+ object_locations,
+ object_quats=None,
+ object_parents=None,
+ joints="default",
+ body_joints=None,
+ sites=None,
+ total_size=None,
+ locations_relative_to_corner=False,
+ ):
+ # Always call superclass first
+ super().__init__()
+
+ self._name = name
+
+ # Set internal variable geometric properties which will be modified later
+ self._object_absolute_positions = {"root": np.zeros(3)} # maps body names to abs positions (rel to root)
+ self._top = 0
+ self._bottom = 0
+ self._horizontal = 0
+
+ # Standardize inputs
+ if isinstance(objects, MujocoObject):
+ self.objects = [objects]
+ elif type(objects) in {list, tuple}:
+ self.objects = list(objects)
+ else:
+ # Invalid objects received
+ raise ValueError("Invalid objects received, got type: {}".format(type(objects)))
+
+ n_objects = len(self.objects)
+ self.object_locations = np.array(object_locations)
+ self.object_quats = deepcopy(object_quats) if object_quats is not None else [None] * n_objects
+ self.object_parents = deepcopy(object_parents) if object_parents is not None else ["root"] * n_objects
+
+ # Set joints
+ if joints == "default":
+ self.joint_specs = [self.get_joint_attrib_template()] # default free joint
+ elif joints is None:
+ self.joint_specs = []
+ else:
+ self.joint_specs = joints
+
+ # Set body joints
+ if body_joints is None:
+ body_joints = {}
+ self.body_joint_specs = body_joints
+
+ # Make sure all joints are named appropriately
+ j_num = 0
+ for joint_spec in self.joint_specs:
+ if "name" not in joint_spec:
+ joint_spec["name"] = "joint{}".format(j_num)
+ j_num += 1
+
+ # Set sites
+ self.site_specs = deepcopy(sites) if sites is not None else []
+ # Add default site
+ site_element_attr = self.get_site_attrib_template()
+ site_element_attr["rgba"] = "1 0 0 0"
+ site_element_attr["name"] = "default_site"
+ self.site_specs.append(site_element_attr)
+
+ # Make sure all sites are named appropriately
+ s_num = 0
+ for site_spec in self.site_specs:
+ if "name" not in site_spec:
+ site_spec["name"] = "site{}".format(s_num)
+ s_num += 1
+
+ self.total_size = np.array(total_size) if total_size is not None else None
+ self.locations_relative_to_corner = locations_relative_to_corner
+ if self.locations_relative_to_corner:
+ assert self.total_size is not None
+
+ # Always run sanity check
+ self.sanity_check()
+
+ # Lastly, parse XML tree appropriately
+ self._obj = self._get_object_subtree()
+
+ # Extract the appropriate private attributes for this
+ self._get_object_properties()
+
+ def _get_object_subtree(self):
+ # Initialize top-level body
+ obj = new_body(name="root")
+
+ # Add all joints and sites
+ for joint_spec in self.joint_specs:
+ obj.append(new_joint(**joint_spec))
+ for site_spec in self.site_specs:
+ obj.append(new_site(**site_spec))
+
+ # Loop through all objects and associated args and append them appropriately
+ for o, o_parent, o_pos, o_quat in zip(
+ self.objects, self.object_parents, self.object_locations, self.object_quats
+ ):
+ self._append_object(root=obj, obj=o, parent_name=o_parent, pos=o_pos, quat=o_quat)
+
+ # Loop through all joints and append them appropriately
+ for body_name, joint_specs in self.body_joint_specs.items():
+ self._append_joints(root=obj, body_name=body_name, joint_specs=joint_specs)
+
+ # Return final object
+ return obj
+
+ def _get_object_properties(self):
+ """
+ Extends the superclass method to add prefixes to all assets
+ """
+ super()._get_object_properties()
+ # Add prefix to all assets
+ add_prefix(root=self.asset, prefix=self.naming_prefix, exclude=self.exclude_from_prefixing)
+
+ def _append_object(self, root, obj, parent_name=None, pos=None, quat=None):
+ """
+ Helper function to add pre-generated object @obj to the body with name @parent_name
+
+ Args:
+ root (ET.Element): Top-level element to iteratively search through for @parent_name to add @obj to
+ obj (MujocoObject): Object to append to the body specified by @parent_name
+ parent_name (None or str): Body name to search for in @root to append @obj to.
+ None defaults to "root" (top-level body)
+ pos (None or 3-array): (x,y,z) relative offset from parent body when appending @obj.
+ None defaults to (0,0,0)
+ quat (None or 4-array) (w,x,y,z) relative quaternion rotation from parent body when appending @obj.
+ None defaults to (1,0,0,0)
+ """
+ # Set defaults if any are None
+ if parent_name is None:
+ parent_name = "root"
+ if pos is None:
+ pos = np.zeros(3)
+ if quat is None:
+ quat = np.array([1, 0, 0, 0])
+ # First, find parent body
+ parent = find_elements(root=root, tags="body", attribs={"name": parent_name}, return_first=True)
+ assert parent is not None, "Could not find parent body with name: {}".format(parent_name)
+ # Get the object xml element tree, remove its top-level joints, and modify its top-level pos / quat
+ child = obj.get_obj()
+ self._remove_joints(child)
+
+ if self.locations_relative_to_corner:
+ # use object location to convert to position coordinate (the origin is the
+ # center of the composite object)
+ cartesian_size = obj.get_bounding_box_half_size()
+ pos = [
+ (-self.total_size[0] + cartesian_size[0]) + pos[0],
+ (-self.total_size[1] + cartesian_size[1]) + pos[1],
+ (-self.total_size[2] + cartesian_size[2]) + pos[2],
+ ]
+
+ child.set("pos", array_to_string(pos))
+ child.set("quat", array_to_string(quat))
+ # Add this object and its assets to this composite object
+ self.merge_assets(other=obj)
+ parent.append(child)
+ # Update geometric properties for this composite object
+ obj_abs_pos = self._object_absolute_positions[parent_name] + np.array(pos)
+ self._object_absolute_positions[obj.root_body] = obj_abs_pos
+ self._top = max(self._top, obj_abs_pos[2] + obj.top_offset[2])
+ self._bottom = min(self._bottom, obj_abs_pos[2] + obj.bottom_offset[2])
+ self._horizontal = max(self._horizontal, max(obj_abs_pos[:2]) + obj.horizontal_radius)
+
+ def _append_joints(self, root, body_name=None, joint_specs="default"):
+ """
+ Appends all joints as specified by @joint_specs to @body.
+
+ Args:
+ root (ET.Element): Top-level element to iteratively search through for @body_name
+ body_name (None or str): Name of the body to append the joints to.
+ None defaults to "root" (top-level body)
+ joint_specs (str or list): List of joint specifications to add to the specified body, or
+ "default", which results in a single free joint
+ """
+ # Standardize inputs
+ if body_name is None:
+ body_name = "root"
+ if joint_specs == "default":
+ joint_specs = [self.get_joint_attrib_template()]
+ for i, joint_spec in enumerate(joint_specs):
+ if "name" not in joint_spec:
+ joint_spec["name"] = f"{body_name}_joint{i}"
+ # Search for body and make sure it exists
+ body = find_elements(root=root, tags="body", attribs={"name": body_name}, return_first=True)
+ assert body is not None, "Could not find body with name: {}".format(body_name)
+ # Add joint(s) to this body
+ for joint_spec in joint_specs:
+ body.append(new_joint(**joint_spec))
+
+ @staticmethod
+ def _remove_joints(body):
+ """
+ Helper function to strip all joints directly appended to the specified @body.
+
+ Args:
+ body (ET.Element): Body to strip joints from
+ """
+ children_to_remove = []
+ for child in body:
+ if child.tag == "joint":
+ children_to_remove.append(child)
+ for child in children_to_remove:
+ body.remove(child)
+
+ @property
+ def bottom_offset(self):
+ return np.array([0.0, 0.0, self._bottom])
+
+ @property
+ def top_offset(self):
+ return np.array([0.0, 0.0, self._top])
+
+ @property
+ def horizontal_radius(self):
+ return self._horizontal
+
+ def get_bounding_box_half_size(self):
+ if self.total_size is not None:
+ return np.array(self.total_size)
+ return super().get_bounding_box_half_size()
+
+
+class CompositeObject(MujocoGeneratedObject):
+ """
+ An object constructed out of basic geoms to make more intricate shapes.
+
+ Note that by default, specifying None for a specific geom element will usually set a value to the mujoco defaults.
+
+ Args:
+ name (str): Name of overall object
+
+ total_size (list): (x, y, z) half-size in each dimension for the bounding box for
+ this Composite object
+
+ geom_types (list): list of geom types in the composite. Must correspond
+ to MuJoCo geom primitives, such as "box" or "capsule".
+
+ geom_locations (list): list of geom locations in the composite. Each
+ location should be a list or tuple of 3 elements and all
+ locations are relative to the lower left corner of the total box
+ (e.g. (0, 0, 0) corresponds to this corner).
+
+ geom_sizes (list): list of geom sizes ordered the same as @geom_locations
+
+ geom_quats (None or list): list of (w, x, y, z) quaternions for each geom.
+
+ geom_names (None or list): list of geom names ordered the same as @geom_locations. The
+ names will get appended with an underscore to the passed name in @get_collision
+ and @get_visual
+
+ geom_rgbas (None or list): list of geom colors ordered the same as @geom_locations. If
+ passed as an argument, @rgba is ignored.
+
+ geom_materials (None or list of CustomTexture): list of custom textures to use for this object material
+
+ geom_frictions (None or list): list of geom frictions to use for each geom.
+
+ rgba (None or list): (r, g, b, a) default values to use if geom-specific @geom_rgbas isn't specified for a given element
+
+ density (float or list of float): either single value to use for all geom densities or geom-specific values
+
+ solref (list or list of list): parameters used for the mujoco contact solver. Can be single set of values or
+ element-specific values. See http://www.mujoco.org/book/modeling.html#CSolver for details.
+
+ solimp (list or list of list): parameters used for the mujoco contact solver. Can be single set of values or
+ element-specific values. See http://www.mujoco.org/book/modeling.html#CSolver for details.
+
+ locations_relative_to_center (bool): If true, @geom_locations will be considered relative to the center of the
+ overall object bounding box defined by @total_size. Else, the corner of this bounding box is considered the
+ origin.
+
+ joints (None or list): Joints to use for this composite object. If None, no joints will be used
+ for this top-level object. If "default", a single free joint will be added to this object.
+ Otherwise, should be a list of dictionaries, where each dictionary should specify the specific
+ joint attributes necessary. See http://www.mujoco.org/book/XMLreference.html#joint for reference.
+
+ sites (None or list): list of sites to add to this composite object. If None, only the default
+ object site will be used. Otherwise, should be a list of dictionaries, where each dictionary
+ should specify the appropriate attributes for the given site.
+ See http://www.mujoco.org/book/XMLreference.html#site for reference.
+
+ obj_types (str or list of str): either single obj_type for all geoms or geom-specific type. Choices are
+ {"collision", "visual", "all"}
+ """
+
+ def __init__(
+ self,
+ name,
+ total_size,
+ geom_types,
+ geom_sizes,
+ geom_locations,
+ geom_quats=None,
+ geom_names=None,
+ geom_rgbas=None,
+ geom_materials=None,
+ geom_frictions=None,
+ geom_condims=None,
+ rgba=None,
+ density=100.0,
+ solref=(0.02, 1.0),
+ solimp=(0.9, 0.95, 0.001),
+ locations_relative_to_center=False,
+ joints="default",
+ sites=None,
+ obj_types="all",
+ duplicate_collision_geoms=True,
+ ):
+ # Always call superclass first
+ super().__init__(duplicate_collision_geoms=duplicate_collision_geoms)
+
+ self._name = name
+
+ # Set joints
+ if joints == "default":
+ self.joint_specs = [self.get_joint_attrib_template()] # default free joint
+ elif joints is None:
+ self.joint_specs = []
+ else:
+ self.joint_specs = joints
+
+ # Make sure all joints are named appropriately
+ j_num = 0
+ for joint_spec in self.joint_specs:
+ if "name" not in joint_spec:
+ joint_spec["name"] = "joint{}".format(j_num)
+ j_num += 1
+
+ # Set sites
+ self.site_specs = deepcopy(sites) if sites is not None else []
+ # Add default site
+ site_element_attr = self.get_site_attrib_template()
+ site_element_attr["rgba"] = "1 0 0 0"
+ site_element_attr["name"] = "default_site"
+ self.site_specs.append(site_element_attr)
+
+ # Make sure all sites are named appropriately
+ s_num = 0
+ for site_spec in self.site_specs:
+ if "name" not in site_spec:
+ site_spec["name"] = "site{}".format(s_num)
+ s_num += 1
+
+ n_geoms = len(geom_types)
+ self.total_size = np.array(total_size)
+ self.geom_types = np.array(geom_types)
+ self.geom_sizes = deepcopy(geom_sizes)
+ self.geom_locations = np.array(geom_locations)
+ self.geom_quats = deepcopy(geom_quats) if geom_quats is not None else [None] * n_geoms
+ self.geom_names = list(geom_names) if geom_names is not None else [None] * n_geoms
+ self.geom_rgbas = list(geom_rgbas) if geom_rgbas is not None else [None] * n_geoms
+ self.geom_materials = list(geom_materials) if geom_materials is not None else [None] * n_geoms
+ self.geom_frictions = list(geom_frictions) if geom_frictions is not None else [None] * n_geoms
+ self.geom_condims = list(geom_condims) if geom_condims is not None else [None] * n_geoms
+ self.density = [density] * n_geoms if density is None or type(density) in {float, int} else list(density)
+ self.solref = [solref] * n_geoms if solref is None or type(solref[0]) in {float, int} else list(solref)
+ self.solimp = [solimp] * n_geoms if obj_types is None or type(solimp[0]) in {float, int} else list(solimp)
+ self.rgba = rgba # override superclass setting of this variable
+ self.locations_relative_to_center = locations_relative_to_center
+ self.obj_types = [obj_types] * n_geoms if obj_types is None or type(obj_types) is str else list(obj_types)
+
+ # Always run sanity check
+ self.sanity_check()
+
+ # Lastly, parse XML tree appropriately
+ self._obj = self._get_object_subtree()
+
+ # Extract the appropriate private attributes for this
+ self._get_object_properties()
+
+ def get_bounding_box_half_size(self):
+ return np.array(self.total_size)
+
+ def in_box(self, position, object_position):
+ """
+ Checks whether the object is contained within this CompositeObject.
+ Useful for when the CompositeObject has holes and the object should
+ be within one of the holes. Makes an approximation by treating the
+ object as a point, and the CompositeBoxObject as an axis-aligned grid.
+ Args:
+ position: 3D body position of CompositeObject
+ object_position: 3D position of object to test for insertion
+ """
+ ub = position + self.total_size
+ lb = position - self.total_size
+
+ # fudge factor for the z-check, since after insertion the object falls to table
+ lb[2] -= 0.01
+
+ return np.all(object_position > lb) and np.all(object_position < ub)
+
+ def _get_object_subtree(self):
+ # Initialize top-level body
+ obj = new_body(name="root")
+
+ # Add all joints and sites
+ for joint_spec in self.joint_specs:
+ obj.append(new_joint(**joint_spec))
+ for site_spec in self.site_specs:
+ obj.append(new_site(**site_spec))
+
+ # Loop through all geoms and generate the composite object
+ for i, (
+ obj_type,
+ g_type,
+ g_size,
+ g_loc,
+ g_name,
+ g_rgba,
+ g_friction,
+ g_condim,
+ g_quat,
+ g_material,
+ g_density,
+ g_solref,
+ g_solimp,
+ ) in enumerate(
+ zip(
+ self.obj_types,
+ self.geom_types,
+ self.geom_sizes,
+ self.geom_locations,
+ self.geom_names,
+ self.geom_rgbas,
+ self.geom_frictions,
+ self.geom_condims,
+ self.geom_quats,
+ self.geom_materials,
+ self.density,
+ self.solref,
+ self.solimp,
+ )
+ ):
+ # geom type
+ geom_type = g_type
+ # get cartesian size from size spec
+ size = g_size
+ cartesian_size = self._size_to_cartesian_half_lengths(geom_type, size)
+ if self.locations_relative_to_center:
+ # no need to convert
+ pos = g_loc
+ else:
+ # use geom location to convert to position coordinate (the origin is the
+ # center of the composite object)
+ pos = [
+ (-self.total_size[0] + cartesian_size[0]) + g_loc[0],
+ (-self.total_size[1] + cartesian_size[1]) + g_loc[1],
+ (-self.total_size[2] + cartesian_size[2]) + g_loc[2],
+ ]
+
+ # geom name
+ geom_name = g_name if g_name is not None else f"g{i}"
+
+ # geom rgba
+ geom_rgba = g_rgba if g_rgba is not None else self.rgba
+
+ # geom friction
+ geom_friction = (
+ array_to_string(g_friction)
+ if g_friction is not None
+ else array_to_string(np.array([1.0, 0.005, 0.0001]))
+ ) # mujoco default
+
+ # Define base geom attributes
+ geom_attr = {
+ "size": size,
+ "pos": pos,
+ "name": geom_name,
+ "type": geom_type,
+ }
+
+ # Optionally define quat if specified
+ if g_quat is not None:
+ geom_attr["quat"] = array_to_string(g_quat)
+
+ # Add collision geom if necessary
+ if obj_type in {"collision", "all"}:
+ col_geom_attr = deepcopy(geom_attr)
+ col_geom_attr.update(self.get_collision_attrib_template())
+ if g_density is not None:
+ col_geom_attr["density"] = str(g_density)
+ col_geom_attr["friction"] = geom_friction
+ col_geom_attr["solref"] = array_to_string(g_solref)
+ col_geom_attr["solimp"] = array_to_string(g_solimp)
+ col_geom_attr["rgba"] = OBJECT_COLLISION_COLOR
+ if g_condim is not None:
+ col_geom_attr["condim"] = str(g_condim)
+ obj.append(new_geom(**col_geom_attr))
+
+ # Add visual geom if necessary
+ if obj_type in {"visual", "all"}:
+ vis_geom_attr = deepcopy(geom_attr)
+ vis_geom_attr.update(self.get_visual_attrib_template())
+ vis_geom_attr["name"] += "_vis"
+ if g_material is not None:
+ vis_geom_attr["material"] = g_material
+ vis_geom_attr["rgba"] = geom_rgba
+ obj.append(new_geom(**vis_geom_attr))
+
+ return obj
+
+ @staticmethod
+ def _size_to_cartesian_half_lengths(geom_type, geom_size):
+ """
+ converts from geom size specification to x, y, and z half-length bounding box
+ """
+ if geom_type in ["box", "ellipsoid"]:
+ return geom_size
+ if geom_type == "sphere":
+ # size is radius
+ return [geom_size[0], geom_size[0], geom_size[0]]
+ if geom_type == "capsule":
+ # size is radius, half-length of cylinder part
+ return [geom_size[0], geom_size[0], geom_size[0] + geom_size[1]]
+ if geom_type == "cylinder":
+ # size is radius, half-length
+ return [geom_size[0], geom_size[0], geom_size[1]]
+ raise Exception("unsupported geom type!")
+
+ @property
+ def bottom_offset(self):
+ return np.array([0.0, 0.0, -self.total_size[2]])
+
+ @property
+ def top_offset(self):
+ return np.array([0.0, 0.0, self.total_size[2]])
+
+ @property
+ def horizontal_radius(self):
+ return np.linalg.norm(self.total_size[:2], 2)
+
+
+class PrimitiveObject(MujocoGeneratedObject):
+ """
+ Base class for all programmatically generated mujoco object
+ i.e., every MujocoObject that does not have an corresponding xml file
+
+ Args:
+ name (str): (unique) name to identify this generated object
+
+ size (n-tuple of float): relevant size parameters for the object, should be of size 1 - 3
+
+ rgba (4-tuple of float): Color
+
+ density (float): Density
+
+ friction (3-tuple of float): (sliding friction, torsional friction, and rolling friction).
+ A single float can also be specified, in order to set the sliding friction (the other values) will
+ be set to the MuJoCo default. See http://www.mujoco.org/book/modeling.html#geom for details.
+
+ solref (2-tuple of float): MuJoCo solver parameters that handle contact.
+ See http://www.mujoco.org/book/XMLreference.html for more details.
+
+ solimp (3-tuple of float): MuJoCo solver parameters that handle contact.
+ See http://www.mujoco.org/book/XMLreference.html for more details.
+
+ material (CustomMaterial or `'default'` or None): if "default", add a template material and texture for this
+ object that is used to color the geom(s).
+ Otherwise, input is expected to be a CustomMaterial object
+
+ See http://www.mujoco.org/book/XMLreference.html#asset for specific details on attributes expected for
+ Mujoco texture / material tags, respectively
+
+ Note that specifying a custom texture in this way automatically overrides any rgba values set
+
+ joints (None or str or list of dict): Joints for this object. If None, no joint will be created. If "default",
+ a single (free) joint will be crated. Else, should be a list of dict, where each dictionary corresponds to
+ a joint that will be created for this object. The dictionary should specify the joint attributes
+ (type, pos, etc.) according to the MuJoCo xml specification.
+
+ obj_type (str): Geom elements to generate / extract for this object. Must be one of:
+
+ :`'collision'`: Only collision geoms are returned (this corresponds to group 0 geoms)
+ :`'visual'`: Only visual geoms are returned (this corresponds to group 1 geoms)
+ :`'all'`: All geoms are returned
+
+ duplicate_collision_geoms (bool): If set, will guarantee that each collision geom has a
+ visual geom copy
+ """
+
+ def __init__(
+ self,
+ name,
+ size=None,
+ rgba=None,
+ density=None,
+ friction=None,
+ solref=None,
+ solimp=None,
+ material=None,
+ joints="default",
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ ):
+ # Always call superclass first
+ super().__init__(obj_type=obj_type, duplicate_collision_geoms=duplicate_collision_geoms)
+
+ # Set name
+ self._name = name
+
+ if size is None:
+ size = [0.05, 0.05, 0.05]
+ self.size = list(size)
+
+ if rgba is None:
+ rgba = [1, 0, 0, 1]
+ assert len(rgba) == 4, "rgba must be a length 4 array"
+ self.rgba = list(rgba)
+
+ if density is None:
+ density = 1000 # water
+ self.density = density
+
+ if friction is None:
+ friction = [1, 0.005, 0.0001] # MuJoCo default
+ elif isinstance(friction, float) or isinstance(friction, int):
+ friction = [friction, 0.005, 0.0001]
+ assert len(friction) == 3, "friction must be a length 3 array or a single number"
+ self.friction = list(friction)
+
+ if solref is None:
+ self.solref = [0.02, 1.0] # MuJoCo default
+ else:
+ self.solref = solref
+
+ if solimp is None:
+ self.solimp = [0.9, 0.95, 0.001] # MuJoCo default
+ else:
+ self.solimp = solimp
+
+ self.material = material
+ if material == "default":
+ # add in default texture and material for this object (for domain randomization)
+ default_tex = CustomMaterial(
+ texture=self.rgba,
+ tex_name="tex",
+ mat_name="mat",
+ )
+ self.append_material(default_tex)
+ elif material is not None:
+ # add in custom texture and material
+ self.append_material(material)
+
+ # joints for this object
+ if joints == "default":
+ self.joint_specs = [self.get_joint_attrib_template()] # default free joint
+ elif joints is None:
+ self.joint_specs = []
+ else:
+ self.joint_specs = joints
+
+ # Make sure all joints have names!
+ for i, joint_spec in enumerate(self.joint_specs):
+ if "name" not in joint_spec:
+ joint_spec["name"] = "joint{}".format(i)
+
+ # Always run sanity check
+ self.sanity_check()
+
+ # Lastly, parse XML tree appropriately
+ self._obj = self._get_object_subtree()
+
+ # Extract the appropriate private attributes for this
+ self._get_object_properties()
+
+ def _get_object_subtree_(self, ob_type="box"):
+ # Create element tree
+ obj = new_body(name="main")
+
+ # Get base element attributes
+ element_attr = {"name": "g0", "type": ob_type, "size": array_to_string(self.size)}
+
+ # Add collision geom if necessary
+ if self.obj_type in {"collision", "all"}:
+ col_element_attr = deepcopy(element_attr)
+ col_element_attr.update(self.get_collision_attrib_template())
+ col_element_attr["density"] = str(self.density)
+ col_element_attr["friction"] = array_to_string(self.friction)
+ col_element_attr["solref"] = array_to_string(self.solref)
+ col_element_attr["solimp"] = array_to_string(self.solimp)
+ obj.append(new_geom(**col_element_attr))
+ # Add visual geom if necessary
+ if self.obj_type in {"visual", "all"}:
+ vis_element_attr = deepcopy(element_attr)
+ vis_element_attr.update(self.get_visual_attrib_template())
+ vis_element_attr["name"] += "_vis"
+ if self.material == "default":
+ vis_element_attr["rgba"] = "0.5 0.5 0.5 1" # mujoco default
+ vis_element_attr["material"] = "mat"
+ elif self.material is not None:
+ vis_element_attr["material"] = self.material.mat_attrib["name"]
+ else:
+ vis_element_attr["rgba"] = array_to_string(self.rgba)
+ obj.append(new_geom(**vis_element_attr))
+ # add joint(s)
+ for joint_spec in self.joint_specs:
+ obj.append(new_joint(**joint_spec))
+ # add a site as well
+ site_element_attr = self.get_site_attrib_template()
+ site_element_attr["name"] = "default_site"
+ obj.append(new_site(**site_element_attr))
+ return obj
+
+ # Methods that still need to be defined by subclass
+ def _get_object_subtree(self):
+ raise NotImplementedError
+
+ def bottom_offset(self):
+ raise NotImplementedError
+
+ def top_offset(self):
+ raise NotImplementedError
+
+ def horizontal_radius(self):
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/group/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/group/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6f000410cd83b63dd4c255d31c6722cfcd319de
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/group/__init__.py
@@ -0,0 +1 @@
+from .transport import TransportGroup
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/group/transport.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/group/transport.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c3bd894dc6d67f41f5e0246f122d97614fa1d8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/group/transport.py
@@ -0,0 +1,174 @@
+import numpy as np
+
+import robosuite.utils.sim_utils as SU
+import robosuite.utils.transform_utils as T
+from robosuite.models.objects import Bin, Lid, ObjectGroup
+
+
+class TransportGroup(ObjectGroup):
+ """
+ Group of objects that capture transporting a payload placed in a start bin to a target bin, while
+ also requiring a piece of trash to be removed from the target bin
+ Args:
+ name (str): Name of that will the prepended to all geom bodies generated for this group
+ payload (MujocoObject): Object that represents payload
+ trash (MujocoObject): Object that represents trash
+ bin_size (3-tuple): (x,y,z) full size of bins to place on tables
+ """
+
+ def __init__(self, name, payload, trash, bin_size=(0.3, 0.3, 0.15)):
+ # Store and initialize internal variables
+ self.payload = payload
+ self.trash = trash
+ self.bin_size = bin_size
+
+ # Create bins and lid
+ self.start_bin = Bin(name=f"{name}_start_bin", bin_size=bin_size, density=10000.0)
+ self.target_bin = Bin(name=f"{name}_target_bin", bin_size=bin_size, density=10000.0)
+ self.trash_bin = Bin(name=f"{name}_trash_bin", bin_size=bin_size, density=10000.0)
+ self.lid = Lid(name=f"{name}_start_bin_lid", lid_size=(*bin_size[:2], 0.01))
+
+ # Relevant geom ids
+ self.payload_geom_ids = None
+ self.trash_geom_ids = None
+ self.target_bin_base_geom_ids = None
+ self.trash_bin_base_geom_ids = None
+ self.lid_handle_geom_ids = None
+ self.payload_body_id = None
+ self.trash_body_id = None
+
+ # Run super init
+ super().__init__(name=name)
+
+ def get_states(self):
+ """
+ Grabs all relevant information for this transport group. Returned dictionary maps keywords to corresponding
+ values pulled from the current sim state.
+ Returns:
+ dict:
+ "lid_handle_pose": list of (pos, quat) of lid handle
+ "payload_pose": list of (pos, quat) of hammer handle
+ "trash_pose": list of (pos, quat) of trash object
+ "target_bin_pos": position of target bin (base geom)
+ "trash_bin_pos": position of trash bin (base geom)
+ "trash_in_trash_bin": True if trash object is touching the base of the trash bin
+ "payload_in_target_bin": True if payload object is touching the base of the target bin
+ """
+ return {
+ "lid_handle_pose": (self.lid_handle_pos, self.lid_handle_quat),
+ "payload_pose": (self.payload_pos, self.payload_quat),
+ "trash_pose": (self.trash_pos, self.trash_quat),
+ "target_bin_pos": self.target_bin_pos,
+ "trash_bin_pos": self.trash_bin_pos,
+ "trash_in_trash_bin": self.trash_in_trash_bin,
+ "payload_in_target_bin": self.payload_in_target_bin,
+ }
+
+ def _generate_objects(self):
+ # Store all relevant objects in self._objects
+ self._objects = {
+ "payload": self.payload,
+ "trash": self.trash,
+ "start_bin": self.start_bin,
+ "target_bin": self.target_bin,
+ "trash_bin": self.trash_bin,
+ "lid": self.lid,
+ }
+
+ def update_sim(self, sim):
+ """
+ Updates internal reference to sim and all other references
+ Args:
+ sim (MjSim): Active mujoco sim reference
+ """
+ # Always run super first
+ super().update_sim(sim=sim)
+
+ # Update internal references to IDs
+ self.payload_geom_ids = [self.sim.model.geom_name2id(geom) for geom in self.payload.contact_geoms]
+ self.trash_geom_ids = [self.sim.model.geom_name2id(geom) for geom in self.trash.contact_geoms]
+ self.target_bin_base_geom_ids = [self.sim.model.geom_name2id(geom) for geom in self.target_bin.base_geoms]
+ self.trash_bin_base_geom_ids = [self.sim.model.geom_name2id(geom) for geom in self.trash_bin.base_geoms]
+ self.lid_handle_geom_ids = [self.sim.model.geom_name2id(geom) for geom in self.lid.handle_geoms]
+ self.payload_body_id = self.sim.model.body_name2id(self.payload.root_body)
+ self.trash_body_id = self.sim.model.body_name2id(self.trash.root_body)
+
+ @property
+ def lid_handle_pos(self):
+ """
+ Returns:
+ np.array: (x,y,z) absolute position of the lid handle
+ """
+ return np.array(self.sim.data.geom_xpos[self.lid_handle_geom_ids[0]])
+
+ @property
+ def lid_handle_quat(self):
+ """
+ Returns:
+ np.array: (x,y,z,w) quaternion of the lid handle
+ """
+ return np.array(T.mat2quat(self.sim.data.geom_xmat[self.lid_handle_geom_ids[0]].reshape(3, 3)))
+
+ @property
+ def payload_pos(self):
+ """
+ Returns:
+ np.array: (x,y,z) absolute position of the payload
+ """
+ return np.array(self.sim.data.body_xpos[self.payload_body_id])
+
+ @property
+ def payload_quat(self):
+ """
+ Returns:
+ np.array: (x,y,z,w) quaternion of the payload
+ """
+ return np.array(T.mat2quat(self.sim.data.body_xmat[self.payload_body_id].reshape(3, 3)))
+
+ @property
+ def trash_pos(self):
+ """
+ Returns:
+ np.array: (x,y,z) absolute position of the trash
+ """
+ return np.array(self.sim.data.body_xpos[self.trash_body_id])
+
+ @property
+ def trash_quat(self):
+ """
+ Returns:
+ np.array: (x,y,z,w) quaternion of the trash
+ """
+ return np.array(T.mat2quat(self.sim.data.body_xmat[self.trash_body_id].reshape(3, 3)))
+
+ @property
+ def target_bin_pos(self):
+ """
+ Returns:
+ np.array: (x,y,z) absolute position of the target bin
+ """
+ return np.array(self.sim.data.geom_xpos[self.target_bin_base_geom_ids[0]])
+
+ @property
+ def trash_bin_pos(self):
+ """
+ Returns:
+ np.array: (x,y,z) absolute position of the trash bin
+ """
+ return np.array(self.sim.data.geom_xpos[self.trash_bin_base_geom_ids[0]])
+
+ @property
+ def trash_in_trash_bin(self):
+ """
+ Returns:
+ bool: True if trash is in trash bin
+ """
+ return SU.check_contact(self.sim, self.trash_bin.base_geoms, self.trash.contact_geoms)
+
+ @property
+ def payload_in_target_bin(self):
+ """
+ Returns:
+ bool: True if payload is in target bin
+ """
+ return SU.check_contact(self.sim, self.target_bin.base_geoms, self.payload.contact_geoms)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/object_groups.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/object_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb1100920fbb4a49db4c30bef139870f215dac30
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/object_groups.py
@@ -0,0 +1,46 @@
+class ObjectGroup:
+ """
+ An abstraction that encompasses a group of objects that interact together in a meaningful way
+ name (str): Name of this object group. This will be prepended to all objects generated by this group.
+ """
+
+ def __init__(self, name):
+ # Store internal variables
+ self.name = name
+ self.sim = None # Reference to shared mjsim object
+ self._objects = {} # maps object names to object class instances
+
+ # Generate objects
+ self._generate_objects()
+
+ def get_states(self):
+ """
+ Function to grab group-relevant states. This should be implemented by the subclass.
+ Returns:
+ dict: Keyword-mapped states for this group
+ """
+ raise NotImplementedError
+
+ def update_sim(self, sim):
+ """
+ Updates internal reference to sim and all other relevant references
+ Args:
+ sim (MjSim): Active mujoco sim reference
+ """
+ self.sim = sim
+
+ def _generate_objects(self):
+ """
+ Internal helper function that generates the objects for this group. Should populate self._objects mapping
+ names of objects to their actual object class instances.
+ """
+ raise NotImplementedError
+
+ @property
+ def objects(self):
+ """
+ Contains references to all objects owned by this group. Mapped from names to object instances
+ Returns:
+ dict: keyword-mapped object class instances
+ """
+ return self._objects
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/objects.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b56dfd9cf9088efad337722ac0ea64b78fb2fec
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/objects.py
@@ -0,0 +1,587 @@
+import copy
+import xml.etree.ElementTree as ET
+from copy import deepcopy
+
+import robosuite.macros as macros
+from robosuite.models.base import MujocoModel, MujocoXML
+from robosuite.utils.mjcf_utils import (
+ OBJECT_COLLISION_COLOR,
+ CustomMaterial,
+ add_material,
+ add_prefix,
+ array_to_string,
+ find_elements,
+ new_joint,
+ sort_elements,
+ string_to_array,
+)
+
+# Dict mapping geom type string keywords to group number
+GEOMTYPE2GROUP = {
+ "collision": {0}, # If we want to use a geom for physics, but NOT visualize
+ "visual": {1}, # If we want to use a geom for visualization, but NOT physics
+ "all": {0, 1}, # If we want to use a geom for BOTH physics + visualization
+}
+
+GEOM_GROUPS = GEOMTYPE2GROUP.keys()
+
+
+class MujocoObject(MujocoModel):
+ """
+ Base class for all objects.
+
+ We use Mujoco Objects to implement all objects that:
+
+ 1) may appear for multiple times in a task
+ 2) can be swapped between different tasks
+
+ Typical methods return copy so the caller can all joints/attributes as wanted
+
+ Args:
+ obj_type (str): Geom elements to generate / extract for this object. Must be one of:
+
+ :`'collision'`: Only collision geoms are returned (this corresponds to group 0 geoms)
+ :`'visual'`: Only visual geoms are returned (this corresponds to group 1 geoms)
+ :`'all'`: All geoms are returned
+
+ duplicate_collision_geoms (bool): If set, will guarantee that each collision geom has a
+ visual geom copy
+
+ """
+
+ def __init__(self, obj_type="all", duplicate_collision_geoms=True):
+ super().__init__()
+ self.asset = ET.Element("asset")
+ assert obj_type in GEOM_GROUPS, "object type must be one in {}, got: {} instead.".format(GEOM_GROUPS, obj_type)
+ self.obj_type = obj_type
+ self.duplicate_collision_geoms = duplicate_collision_geoms
+
+ # Attributes that should be filled in within the subclass
+ self._name = None
+ self._obj = None
+
+ # Attributes that are auto-filled by _get_object_properties call
+ self._root_body = None
+ self._bodies = None
+ self._joints = None
+ self._actuators = None
+ self._sites = None
+ self._contact_geoms = None
+ self._visual_geoms = None
+
+ def merge_assets(self, other):
+ """
+ Merges @other's assets in a custom logic.
+
+ Args:
+ other (MujocoXML or MujocoObject): other xml file whose assets will be merged into this one
+ """
+ for asset in other.asset:
+ if (
+ find_elements(root=self.asset, tags=asset.tag, attribs={"name": asset.get("name")}, return_first=True)
+ is None
+ ):
+ self.asset.append(asset)
+
+ def get_obj(self):
+ """
+ Returns the generated / extracted object, in XML ElementTree form.
+
+ Returns:
+ ET.Element: Object in XML form.
+ """
+ assert self._obj is not None, "Object XML tree has not been generated yet!"
+ return self._obj
+
+ def exclude_from_prefixing(self, inp):
+ """
+ A function that should take in either an ET.Element or its attribute (str) and return either True or False,
+ determining whether the corresponding name / str to @inp should have naming_prefix added to it.
+ Must be defined by subclass.
+
+ Args:
+ inp (ET.Element or str): Element or its attribute to check for prefixing.
+
+ Returns:
+ bool: True if we should exclude the associated name(s) with @inp from being prefixed with naming_prefix
+ """
+ raise NotImplementedError
+
+ def _get_object_subtree(self):
+
+ """
+ Returns a ET.Element
+ It is a subtree that defines all collision and / or visualization related fields
+ of this object.
+ Return should be a copy.
+ Must be defined by subclass.
+
+ Returns:
+ ET.Element: body
+ """
+ raise NotImplementedError
+
+ def _get_object_properties(self):
+ """
+ Helper function to extract relevant object properties (bodies, joints, contact/visual geoms, etc...) from this
+ object's XML tree. Assumes the self._obj attribute has already been filled.
+ """
+ # Parse element tree to get all relevant bodies, joints, actuators, and geom groups
+ _elements = sort_elements(root=self.get_obj())
+ assert (
+ len(_elements["root_body"]) == 1
+ ), "Invalid number of root bodies found for robot model. Expected 1," "got {}".format(
+ len(_elements["root_body"])
+ )
+ _elements["root_body"] = _elements["root_body"][0]
+ _elements["bodies"] = (
+ [_elements["root_body"]] + _elements["bodies"] if "bodies" in _elements else [_elements["root_body"]]
+ )
+ self._root_body = _elements["root_body"].get("name")
+ self._bodies = [e.get("name") for e in _elements.get("bodies", [])]
+ self._joints = [e.get("name") for e in _elements.get("joints", [])]
+ self._actuators = [e.get("name") for e in _elements.get("actuators", [])]
+ self._sites = [e.get("name") for e in _elements.get("sites", [])]
+ self._sensors = [e.get("name") for e in _elements.get("sensors", [])]
+ self._contact_geoms = [e.get("name") for e in _elements.get("contact_geoms", [])]
+ self._visual_geoms = [e.get("name") for e in _elements.get("visual_geoms", [])]
+
+ # Add default materials if we're using domain randomization
+ if macros.USING_INSTANCE_RANDOMIZATION:
+ tex_element, mat_element, _, used = add_material(root=self.get_obj(), naming_prefix=self.naming_prefix)
+ # Only add the material / texture if they were actually used
+ if used:
+ self.asset.append(tex_element)
+ self.asset.append(mat_element)
+
+ # Add prefix to all elements
+ add_prefix(root=self.get_obj(), prefix=self.naming_prefix, exclude=self.exclude_from_prefixing)
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def naming_prefix(self):
+ return "{}_".format(self.name)
+
+ @property
+ def root_body(self):
+ return self.correct_naming(self._root_body)
+
+ @property
+ def bodies(self):
+ return self.correct_naming(self._bodies)
+
+ @property
+ def joints(self):
+ return self.correct_naming(self._joints)
+
+ @property
+ def actuators(self):
+ return self.correct_naming(self._actuators)
+
+ @property
+ def sites(self):
+ return self.correct_naming(self._sites)
+
+ @property
+ def sensors(self):
+ return self.correct_naming(self._sensors)
+
+ @property
+ def contact_geoms(self):
+ return self.correct_naming(self._contact_geoms)
+
+ @property
+ def visual_geoms(self):
+ return self.correct_naming(self._visual_geoms)
+
+ @property
+ def important_geoms(self):
+ """
+ Returns:
+ dict: (Default is no important geoms; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def important_sites(self):
+ """
+ Returns:
+ dict:
+
+ :`obj`: Object default site
+ """
+ return {"obj": self.naming_prefix + "default_site"}
+
+ @property
+ def important_sensors(self):
+ """
+ Returns:
+ dict: (Default is no sensors; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def bottom_offset(self):
+ """
+ Returns vector from model root body to model bottom.
+ Useful for, e.g. placing models on a surface.
+ Must be defined by subclass.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ raise NotImplementedError
+
+ @property
+ def top_offset(self):
+ """
+ Returns vector from model root body to model top.
+ Useful for, e.g. placing models on a surface.
+ Must be defined by subclass.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ raise NotImplementedError
+
+ @property
+ def horizontal_radius(self):
+ """
+ Returns maximum distance from model root body to any radial point of the model.
+
+ Helps us put models programmatically without them flying away due to a huge initial contact force.
+ Must be defined by subclass.
+
+ Returns:
+ float: radius
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def get_site_attrib_template():
+ """
+ Returns attribs of spherical site used to mark body origin
+
+ Returns:
+ dict: Dictionary of default site attributes
+ """
+ return {
+ "pos": "0 0 0",
+ "size": "0.002 0.002 0.002",
+ "rgba": "1 0 0 1",
+ "type": "sphere",
+ "group": "0",
+ }
+
+ @staticmethod
+ def get_joint_attrib_template():
+ """
+ Returns attribs of free joint
+
+ Returns:
+ dict: Dictionary of default joint attributes
+ """
+ return {
+ "type": "free",
+ }
+
+ def get_bounding_box_half_size(self):
+ raise NotImplementedError
+
+ def get_bounding_box_size(self):
+ """
+ Returns numpy array with dimensions of a bounding box around this object.
+ """
+ return 2. * self.get_bounding_box_half_size()
+
+
+class MujocoXMLObject(MujocoObject, MujocoXML):
+ """
+ MujocoObjects that are loaded from xml files (by default, inherit all properties (e.g.: name)
+ from MujocoObject class first!)
+
+ Args:
+ fname (str): XML File path
+
+ name (str): Name of this MujocoXMLObject
+
+ joints (None or str or list of dict): each dictionary corresponds to a joint that will be created for this
+ object. The dictionary should specify the joint attributes (type, pos, etc.) according to the MuJoCo xml
+ specification. If "default", a single free-joint will be automatically generated. If None, no joints will
+ be created.
+
+ obj_type (str): Geom elements to generate / extract for this object. Must be one of:
+
+ :`'collision'`: Only collision geoms are returned (this corresponds to group 0 geoms)
+ :`'visual'`: Only visual geoms are returned (this corresponds to group 1 geoms)
+ :`'all'`: All geoms are returned
+
+ duplicate_collision_geoms (bool): If set, will guarantee that each collision geom has a
+ visual geom copy
+ """
+
+ def __init__(self, fname, name, joints="default", obj_type="all", duplicate_collision_geoms=True):
+ MujocoXML.__init__(self, fname)
+ # Set obj type and duplicate args
+ assert obj_type in GEOM_GROUPS, "object type must be one in {}, got: {} instead.".format(GEOM_GROUPS, obj_type)
+ self.obj_type = obj_type
+ self.duplicate_collision_geoms = duplicate_collision_geoms
+
+ # Set name
+ self._name = name
+
+ # joints for this object
+ if joints == "default":
+ self.joint_specs = [self.get_joint_attrib_template()] # default free joint
+ elif joints is None:
+ self.joint_specs = []
+ else:
+ self.joint_specs = joints
+
+ # Make sure all joints have names!
+ for i, joint_spec in enumerate(self.joint_specs):
+ if "name" not in joint_spec:
+ joint_spec["name"] = "joint{}".format(i)
+
+ # Lastly, parse XML tree appropriately
+ self._obj = self._get_object_subtree()
+
+ # Extract the appropriate private attributes for this
+ self._get_object_properties()
+
+ def _get_object_subtree(self):
+ # Parse object
+ obj = copy.deepcopy(self.worldbody.find("./body/body[@name='object']"))
+ # Rename this top level object body (will have self.naming_prefix added later)
+ obj.attrib["name"] = "main"
+ # Get all geom_pairs in this tree
+ geom_pairs = self._get_geoms(obj)
+
+ # Define a temp function so we don't duplicate so much code
+ obj_type = self.obj_type
+
+ def _should_keep(el):
+ return int(el.get("group")) in GEOMTYPE2GROUP[obj_type]
+
+ # Loop through each of these pairs and modify them according to @elements arg
+ for i, (parent, element) in enumerate(geom_pairs):
+ # Delete non-relevant geoms and rename remaining ones
+ if not _should_keep(element):
+ parent.remove(element)
+ else:
+ g_name = element.get("name")
+ g_name = g_name if g_name is not None else f"g{i}"
+ element.set("name", g_name)
+ # Also optionally duplicate collision geoms if requested (and this is a collision geom)
+ if self.duplicate_collision_geoms and element.get("group") in {None, "0"}:
+ parent.append(self._duplicate_visual_from_collision(element))
+ # Also manually set the visual appearances to the original collision model
+ element.set("rgba", array_to_string(OBJECT_COLLISION_COLOR))
+ if element.get("material") is not None:
+ del element.attrib["material"]
+ # add joint(s)
+ for joint_spec in self.joint_specs:
+ obj.append(new_joint(**joint_spec))
+ # Lastly, add a site for this object
+ template = self.get_site_attrib_template()
+ template["rgba"] = "1 0 0 0"
+ template["name"] = "default_site"
+ obj.append(ET.Element("site", attrib=template))
+
+ return obj
+
+ def exclude_from_prefixing(self, inp):
+ """
+ By default, don't exclude any from being prefixed
+ """
+ return False
+
+ def _get_object_properties(self):
+ """
+ Extends the base class method to also add prefixes to all bodies in this object
+ """
+ super()._get_object_properties()
+ add_prefix(root=self.root, prefix=self.naming_prefix, exclude=self.exclude_from_prefixing)
+
+ @staticmethod
+ def _duplicate_visual_from_collision(element):
+ """
+ Helper function to duplicate a geom element to be a visual element. Namely, this corresponds to the
+ following attribute requirements: group=1, conaffinity/contype=0, no mass, name appended with "_visual"
+
+ Args:
+ element (ET.Element): element to duplicate as a visual geom
+
+ Returns:
+ element (ET.Element): duplicated element
+ """
+ # Copy element
+ vis_element = deepcopy(element)
+ # Modify for visual-specific attributes (group=1, conaffinity/contype=0, no mass, update name)
+ vis_element.set("group", "1")
+ vis_element.set("conaffinity", "0")
+ vis_element.set("contype", "0")
+ vis_element.set("mass", "1e-8")
+ vis_element.set("name", vis_element.get("name") + "_visual")
+ return vis_element
+
+ def _get_geoms(self, root, _parent=None):
+ """
+ Helper function to recursively search through element tree starting at @root and returns
+ a list of (parent, child) tuples where the child is a geom element
+
+ Args:
+ root (ET.Element): Root of xml element tree to start recursively searching through
+ _parent (ET.Element): Parent of the root element tree. Should not be used externally; only set
+ during the recursive call
+
+ Returns:
+ list: array of (parent, child) tuples where the child element is a geom type
+ """
+ # Initialize return array
+ geom_pairs = []
+ # If the parent exists and this is a geom element, we add this current (parent, element) combo to the output
+ if _parent is not None and root.tag == "geom":
+ geom_pairs.append((_parent, root))
+ # Loop through all children elements recursively and add to pairs
+ for child in root:
+ geom_pairs += self._get_geoms(child, _parent=root)
+ # Return all found pairs
+ return geom_pairs
+
+ @property
+ def bottom_offset(self):
+ bottom_site = self.worldbody.find("./body/site[@name='{}bottom_site']".format(self.naming_prefix))
+ return string_to_array(bottom_site.get("pos"))
+
+ @property
+ def top_offset(self):
+ top_site = self.worldbody.find("./body/site[@name='{}top_site']".format(self.naming_prefix))
+ return string_to_array(top_site.get("pos"))
+
+ @property
+ def horizontal_radius(self):
+ horizontal_radius_site = self.worldbody.find(
+ "./body/site[@name='{}horizontal_radius_site']".format(self.naming_prefix)
+ )
+ return string_to_array(horizontal_radius_site.get("pos"))[0]
+
+ def get_bounding_box_half_size(self):
+ horizontal_radius_site = self.worldbody.find(
+ "./body/site[@name='{}horizontal_radius_site']".format(self.naming_prefix)
+ )
+ return string_to_array(horizontal_radius_site.get("pos")) - self.bottom_offset
+
+
+class MujocoGeneratedObject(MujocoObject):
+ """
+ Base class for all procedurally generated objects.
+
+ Args:
+ obj_type (str): Geom elements to generate / extract for this object. Must be one of:
+
+ :`'collision'`: Only collision geoms are returned (this corresponds to group 0 geoms)
+ :`'visual'`: Only visual geoms are returned (this corresponds to group 1 geoms)
+ :`'all'`: All geoms are returned
+
+ duplicate_collision_geoms (bool): If set, will guarantee that each collision geom has a
+ visual geom copy
+ """
+
+ def __init__(self, obj_type="all", duplicate_collision_geoms=True):
+ super().__init__(obj_type=obj_type, duplicate_collision_geoms=duplicate_collision_geoms)
+
+ # Store common material names so we don't add prefixes to them
+ self.shared_materials = set()
+ self.shared_textures = set()
+
+ def sanity_check(self):
+ """
+ Checks if data provided makes sense.
+ Called in __init__()
+ For subclasses to inherit from
+ """
+ pass
+
+ @staticmethod
+ def get_collision_attrib_template():
+ """
+ Generates template with collision attributes for a given geom
+
+ Returns:
+ dict: Initial template with `'pos'` and `'group'` already specified
+ """
+ return {"group": "0", "rgba": array_to_string(OBJECT_COLLISION_COLOR)}
+
+ @staticmethod
+ def get_visual_attrib_template():
+ """
+ Generates template with visual attributes for a given geom
+
+ Returns:
+ dict: Initial template with `'conaffinity'`, `'contype'`, and `'group'` already specified
+ """
+ return {"conaffinity": "0", "contype": "0", "mass": "1e-8", "group": "1"}
+
+ def append_material(self, material):
+ """
+ Adds a new texture / material combination to the assets subtree of this XML
+ Input is expected to be a CustomMaterial object
+
+ See http://www.mujoco.org/book/XMLreference.html#asset for specific details on attributes expected for
+ Mujoco texture / material tags, respectively
+
+ Note that the "file" attribute for the "texture" tag should be specified relative to the textures directory
+ located in robosuite/models/assets/textures/
+
+ Args:
+ material (CustomMaterial): Material to add to this object
+ """
+ # First check if asset attribute exists; if not, define the asset attribute
+ if not hasattr(self, "asset"):
+ self.asset = ET.Element("asset")
+ # If the material name is not in shared materials, add this to our assets
+ if material.name not in self.shared_materials:
+ self.asset.append(ET.Element("texture", attrib=material.tex_attrib))
+ self.asset.append(ET.Element("material", attrib=material.mat_attrib))
+ # Add this material name to shared materials if it should be shared
+ if material.shared:
+ self.shared_materials.add(material.name)
+ self.shared_textures.add(material.tex_attrib["name"])
+ # Update prefix for assets
+ add_prefix(root=self.asset, prefix=self.naming_prefix, exclude=self.exclude_from_prefixing)
+
+ def exclude_from_prefixing(self, inp):
+ """
+ Exclude all shared materials and their associated names from being prefixed.
+
+ Args:
+ inp (ET.Element or str): Element or its attribute to check for prefixing.
+
+ Returns:
+ bool: True if we should exclude the associated name(s) with @inp from being prefixed with naming_prefix
+ """
+ # Automatically return False if this is not of type "str"
+ if type(inp) is not str:
+ return False
+ # Only return True if the string matches the name of a common material
+ return True if inp in self.shared_materials or inp in self.shared_textures else False
+
+ # Methods that still need to be defined by subclass
+ def _get_object_subtree(self):
+ raise NotImplementedError
+
+ def bottom_offset(self):
+ raise NotImplementedError
+
+ def top_offset(self):
+ raise NotImplementedError
+
+ def horizontal_radius(self):
+ raise NotImplementedError
+
+ def get_bounding_box_half_size(self):
+ return np.array([self.horizontal_radius, self.horizontal_radius, 0.]) - self.bottom_offset
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b224ac36f77f0c71cc7b6e38fd9b697f75f2823
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/__init__.py
@@ -0,0 +1,4 @@
+from .ball import BallObject
+from .box import BoxObject
+from .capsule import CapsuleObject
+from .cylinder import CylinderObject
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/ball.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/ball.py
new file mode 100644
index 0000000000000000000000000000000000000000..95c6621dedadaf436ecb33dc3720fc8db7ddbacf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/ball.py
@@ -0,0 +1,71 @@
+import numpy as np
+
+from robosuite.models.objects import PrimitiveObject
+from robosuite.utils.mjcf_utils import get_size
+
+
+class BallObject(PrimitiveObject):
+ """
+ A ball (sphere) object.
+
+ Args:
+ size (1-tuple of float): (radius) size parameters for this ball object
+ """
+
+ def __init__(
+ self,
+ name,
+ size=None,
+ size_max=None,
+ size_min=None,
+ density=None,
+ friction=None,
+ rgba=None,
+ solref=None,
+ solimp=None,
+ material=None,
+ joints="default",
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ ):
+ size = get_size(size, size_max, size_min, [0.07], [0.03])
+ super().__init__(
+ name=name,
+ size=size,
+ rgba=rgba,
+ density=density,
+ friction=friction,
+ solref=solref,
+ solimp=solimp,
+ material=material,
+ joints=joints,
+ obj_type=obj_type,
+ duplicate_collision_geoms=duplicate_collision_geoms,
+ )
+
+ def sanity_check(self):
+ """
+ Checks to make sure inputted size is of correct length
+
+ Raises:
+ AssertionError: [Invalid size length]
+ """
+ assert len(self.size) == 1, "ball size should have length 1"
+
+ def _get_object_subtree(self):
+ return self._get_object_subtree_(ob_type="sphere")
+
+ @property
+ def bottom_offset(self):
+ return np.array([0, 0, -1 * self.size[0]])
+
+ @property
+ def top_offset(self):
+ return np.array([0, 0, self.size[0]])
+
+ @property
+ def horizontal_radius(self):
+ return self.size[0]
+
+ def get_bounding_box_half_size(self):
+ return np.array([self.size[0], self.size[0], self.size[0]])
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/box.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/box.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fcb660be972e8862887be7508d16d7bbe6307c8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/box.py
@@ -0,0 +1,71 @@
+import numpy as np
+
+from robosuite.models.objects import PrimitiveObject
+from robosuite.utils.mjcf_utils import get_size
+
+
+class BoxObject(PrimitiveObject):
+ """
+ A box object.
+
+ Args:
+ size (3-tuple of float): (half-x, half-y, half-z) size parameters for this box object
+ """
+
+ def __init__(
+ self,
+ name,
+ size=None,
+ size_max=None,
+ size_min=None,
+ density=None,
+ friction=None,
+ rgba=None,
+ solref=None,
+ solimp=None,
+ material=None,
+ joints="default",
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ ):
+ size = get_size(size, size_max, size_min, [0.07, 0.07, 0.07], [0.03, 0.03, 0.03])
+ super().__init__(
+ name=name,
+ size=size,
+ rgba=rgba,
+ density=density,
+ friction=friction,
+ solref=solref,
+ solimp=solimp,
+ material=material,
+ joints=joints,
+ obj_type=obj_type,
+ duplicate_collision_geoms=duplicate_collision_geoms,
+ )
+
+ def sanity_check(self):
+ """
+ Checks to make sure inputted size is of correct length
+
+ Raises:
+ AssertionError: [Invalid size length]
+ """
+ assert len(self.size) == 3, "box size should have length 3"
+
+ def _get_object_subtree(self):
+ return self._get_object_subtree_(ob_type="box")
+
+ @property
+ def bottom_offset(self):
+ return np.array([0, 0, -1 * self.size[2]])
+
+ @property
+ def top_offset(self):
+ return np.array([0, 0, self.size[2]])
+
+ @property
+ def horizontal_radius(self):
+ return np.linalg.norm(self.size[0:2], 2)
+
+ def get_bounding_box_half_size(self):
+ return np.array([self.size[0], self.size[1], self.size[2]])
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/capsule.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/capsule.py
new file mode 100644
index 0000000000000000000000000000000000000000..6139cb3473d1d658a169f5620775722e75581929
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/capsule.py
@@ -0,0 +1,71 @@
+import numpy as np
+
+from robosuite.models.objects import PrimitiveObject
+from robosuite.utils.mjcf_utils import get_size
+
+
+class CapsuleObject(PrimitiveObject):
+ """
+ A capsule object.
+
+ Args:
+ size (2-tuple of float): (radius, half-length) size parameters for this capsule object
+ """
+
+ def __init__(
+ self,
+ name,
+ size=None,
+ size_max=None,
+ size_min=None,
+ density=None,
+ friction=None,
+ rgba=None,
+ solref=None,
+ solimp=None,
+ material=None,
+ joints="default",
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ ):
+ size = get_size(size, size_max, size_min, [0.07, 0.07], [0.03, 0.03])
+ super().__init__(
+ name=name,
+ size=size,
+ rgba=rgba,
+ density=density,
+ friction=friction,
+ solref=solref,
+ solimp=solimp,
+ material=material,
+ joints=joints,
+ obj_type=obj_type,
+ duplicate_collision_geoms=duplicate_collision_geoms,
+ )
+
+ def sanity_check(self):
+ """
+ Checks to make sure inputted size is of correct length
+
+ Raises:
+ AssertionError: [Invalid size length]
+ """
+ assert len(self.size) == 2, "capsule size should have length 2"
+
+ def _get_object_subtree(self):
+ return self._get_object_subtree_(ob_type="capsule")
+
+ @property
+ def bottom_offset(self):
+ return np.array([0, 0, -1 * (self.size[0] + self.size[1])])
+
+ @property
+ def top_offset(self):
+ return np.array([0, 0, (self.size[0] + self.size[1])])
+
+ @property
+ def horizontal_radius(self):
+ return self.size[0]
+
+ def get_bounding_box_half_size(self):
+ return np.array([self.size[0], self.size[0], self.size[0] + self.size[1]])
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/cylinder.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/cylinder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e2dc9af1d6ab3a4c10226b141f6a181437cd44a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/primitive/cylinder.py
@@ -0,0 +1,95 @@
+import numpy as np
+
+from robosuite.models.objects import MujocoGeneratedObject, PrimitiveObject
+from robosuite.utils.mjcf_utils import get_size
+
+
+class CylinderObject(PrimitiveObject):
+ """
+ A cylinder object.
+
+ Args:
+ size (2-tuple of float): (radius, half-length) size parameters for this cylinder object
+ """
+
+ def __init__(
+ self,
+ name,
+ size=None,
+ size_max=None,
+ size_min=None,
+ density=None,
+ friction=None,
+ rgba=None,
+ solref=None,
+ solimp=None,
+ material=None,
+ joints="default",
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ ):
+ size = get_size(size, size_max, size_min, [0.07, 0.07], [0.03, 0.03])
+
+ # We override solref, solimp, and joint default values for better stability
+ if friction is None:
+ friction = [1, 0.01, 0.001]
+ if solref is None:
+ solref = [0.01, 0.5]
+ if joints == "default":
+ joints = [{"type": "free", "damping": "0.0001"}]
+
+ super().__init__(
+ name=name,
+ size=size,
+ rgba=rgba,
+ density=density,
+ friction=friction,
+ solref=solref,
+ solimp=solimp,
+ material=material,
+ joints=joints,
+ obj_type=obj_type,
+ duplicate_collision_geoms=duplicate_collision_geoms,
+ )
+
+ def sanity_check(self):
+ """
+ Checks to make sure inputted size is of correct length
+
+ Raises:
+ AssertionError: [Invalid size length]
+ """
+ assert len(self.size) == 2, "cylinder size should have length 2"
+
+ def _get_object_subtree(self):
+ return self._get_object_subtree_(ob_type="cylinder")
+
+ @staticmethod
+ def get_collision_attrib_template():
+ """
+ Generates template with collision attributes for a given geom
+
+ Extends super method for better stability for contacts
+
+ Returns:
+ dict: Initial template with `'pos'` and `'group'` already specified
+ """
+ template = MujocoGeneratedObject.get_collision_attrib_template()
+ # Add condim value
+ template["margin"] = "0.001"
+ return template
+
+ @property
+ def bottom_offset(self):
+ return np.array([0, 0, -1 * self.size[1]])
+
+ @property
+ def top_offset(self):
+ return np.array([0, 0, self.size[1]])
+
+ @property
+ def horizontal_radius(self):
+ return self.size[0]
+
+ def get_bounding_box_half_size(self):
+ return np.array([self.size[0], self.size[0], self.size[1]])
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/objects/xml_objects.py b/phantom/submodules/phantom-robosuite/robosuite/models/objects/xml_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..68e0369d469d33e8052e6e0157182bb65e0b194d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/objects/xml_objects.py
@@ -0,0 +1,299 @@
+import numpy as np
+
+from robosuite.models.objects import MujocoXMLObject
+from robosuite.utils.mjcf_utils import array_to_string, find_elements, xml_path_completion
+
+
+class BottleObject(MujocoXMLObject):
+ """
+ Bottle object
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/bottle.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+
+class CanObject(MujocoXMLObject):
+ """
+ Coke can object (used in PickPlace)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/can.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+
+class LemonObject(MujocoXMLObject):
+ """
+ Lemon object
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/lemon.xml"), name=name, obj_type="all", duplicate_collision_geoms=True
+ )
+
+
+class MilkObject(MujocoXMLObject):
+ """
+ Milk carton object (used in PickPlace)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/milk.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+
+class BreadObject(MujocoXMLObject):
+ """
+ Bread loaf object (used in PickPlace)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/bread.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+
+class CerealObject(MujocoXMLObject):
+ """
+ Cereal box object (used in PickPlace)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/cereal.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+
+class SquareNutObject(MujocoXMLObject):
+ """
+ Square nut object (used in NutAssembly)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/square-nut.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+ @property
+ def important_sites(self):
+ """
+ Returns:
+ dict: In addition to any default sites for this object, also provides the following entries
+
+ :`'handle'`: Name of nut handle location site
+ """
+ # Get dict from super call and add to it
+ dic = super().important_sites
+ dic.update({"handle": self.naming_prefix + "handle_site"})
+ return dic
+
+
+class RoundNutObject(MujocoXMLObject):
+ """
+ Round nut (used in NutAssembly)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/round-nut.xml"),
+ name=name,
+ joints=[dict(type="free", damping="0.0005")],
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+ @property
+ def important_sites(self):
+ """
+ Returns:
+ dict: In addition to any default sites for this object, also provides the following entries
+
+ :`'handle'`: Name of nut handle location site
+ """
+ # Get dict from super call and add to it
+ dic = super().important_sites
+ dic.update({"handle": self.naming_prefix + "handle_site"})
+ return dic
+
+
+class MilkVisualObject(MujocoXMLObject):
+ """
+ Visual fiducial of milk carton (used in PickPlace).
+
+ Fiducial objects are not involved in collision physics.
+ They provide a point of reference to indicate a position.
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/milk-visual.xml"),
+ name=name,
+ joints=None,
+ obj_type="visual",
+ duplicate_collision_geoms=True,
+ )
+
+
+class BreadVisualObject(MujocoXMLObject):
+ """
+ Visual fiducial of bread loaf (used in PickPlace)
+
+ Fiducial objects are not involved in collision physics.
+ They provide a point of reference to indicate a position.
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/bread-visual.xml"),
+ name=name,
+ joints=None,
+ obj_type="visual",
+ duplicate_collision_geoms=True,
+ )
+
+
+class CerealVisualObject(MujocoXMLObject):
+ """
+ Visual fiducial of cereal box (used in PickPlace)
+
+ Fiducial objects are not involved in collision physics.
+ They provide a point of reference to indicate a position.
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/cereal-visual.xml"),
+ name=name,
+ joints=None,
+ obj_type="visual",
+ duplicate_collision_geoms=True,
+ )
+
+
+class CanVisualObject(MujocoXMLObject):
+ """
+ Visual fiducial of coke can (used in PickPlace)
+
+ Fiducial objects are not involved in collision physics.
+ They provide a point of reference to indicate a position.
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/can-visual.xml"),
+ name=name,
+ joints=None,
+ obj_type="visual",
+ duplicate_collision_geoms=True,
+ )
+
+
+class PlateWithHoleObject(MujocoXMLObject):
+ """
+ Square plate with a hole in the center (used in PegInHole)
+ """
+
+ def __init__(self, name):
+ super().__init__(
+ xml_path_completion("objects/plate-with-hole.xml"),
+ name=name,
+ joints=None,
+ obj_type="all",
+ duplicate_collision_geoms=True,
+ )
+
+
+class DoorObject(MujocoXMLObject):
+ """
+ Door with handle (used in Door)
+
+ Args:
+ friction (3-tuple of float): friction parameters to override the ones specified in the XML
+ damping (float): damping parameter to override the ones specified in the XML
+ lock (bool): Whether to use the locked door variation object or not
+ """
+
+ def __init__(self, name, friction=None, damping=None, lock=False):
+ xml_path = "objects/door.xml"
+ if lock:
+ xml_path = "objects/door_lock.xml"
+ super().__init__(
+ xml_path_completion(xml_path), name=name, joints=None, obj_type="all", duplicate_collision_geoms=True
+ )
+
+ # Set relevant body names
+ self.door_body = self.naming_prefix + "door"
+ self.frame_body = self.naming_prefix + "frame"
+ self.latch_body = self.naming_prefix + "latch"
+ self.hinge_joint = self.naming_prefix + "hinge"
+
+ self.lock = lock
+ self.friction = friction
+ self.damping = damping
+ if self.friction is not None:
+ self._set_door_friction(self.friction)
+ if self.damping is not None:
+ self._set_door_damping(self.damping)
+
+ def _set_door_friction(self, friction):
+ """
+ Helper function to override the door friction directly in the XML
+
+ Args:
+ friction (3-tuple of float): friction parameters to override the ones specified in the XML
+ """
+ hinge = find_elements(root=self.worldbody, tags="joint", attribs={"name": self.hinge_joint}, return_first=True)
+ hinge.set("frictionloss", array_to_string(np.array([friction])))
+
+ def _set_door_damping(self, damping):
+ """
+ Helper function to override the door friction directly in the XML
+
+ Args:
+ damping (float): damping parameter to override the ones specified in the XML
+ """
+ hinge = find_elements(root=self.worldbody, tags="joint", attribs={"name": self.hinge_joint}, return_first=True)
+ hinge.set("damping", array_to_string(np.array([damping])))
+
+ @property
+ def important_sites(self):
+ """
+ Returns:
+ dict: In addition to any default sites for this object, also provides the following entries
+
+ :`'handle'`: Name of door handle location site
+ """
+ # Get dict from super call and add to it
+ dic = super().important_sites
+ dic.update({"handle": self.naming_prefix + "handle"})
+ return dic
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..79787df25e2de33708dba56ee2ffa30c71af5b00
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/__init__.py
@@ -0,0 +1,2 @@
+from .robot_model import RobotModel, create_robot
+from .manipulators import *
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6d2c177d05d0cf19aa37d13e723af97be9eea15
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/__init__.py
@@ -0,0 +1,8 @@
+from .manipulator_model import ManipulatorModel
+from .sawyer_robot import Sawyer
+from .baxter_robot import Baxter
+from .panda_robot import Panda
+from .jaco_robot import Jaco
+from .kinova3_robot import Kinova3
+from .iiwa_robot import IIWA
+from .ur5e_robot import UR5e
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/baxter_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/baxter_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..631dba7e16da4e5fac42719c9bb81ec5a7f3f254
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/baxter_robot.py
@@ -0,0 +1,89 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Baxter(ManipulatorModel):
+ """
+ Baxter is a hunky bimanual robot designed by Rethink Robotics.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/baxter/robot.xml"), idn=idn)
+
+ @property
+ def default_mount(self):
+ return "RethinkMinimalMount"
+
+ @property
+ def default_gripper(self):
+ """
+ Since this is bimanual robot, returns dict with `'right'`, `'left'` keywords corresponding to their respective
+ values
+
+ Returns:
+ dict: Dictionary containing arm-specific gripper names
+ """
+ return {"right": "RethinkGripper", "left": "RethinkGripper"}
+
+ @property
+ def default_controller_config(self):
+ """
+ Since this is bimanual robot, returns dict with `'right'`, `'left'` keywords corresponding to their respective
+ values
+
+ Returns:
+ dict: Dictionary containing arm-specific default controller config names
+ """
+ return {"right": "default_baxter", "left": "default_baxter"}
+
+ @property
+ def init_qpos(self):
+ """
+ Since this is bimanual robot, returns [right, left] array corresponding to respective values
+
+ Note that this is a pose such that the arms are half extended
+
+ Returns:
+ np.array: default initial qpos for the right, left arms
+ """
+ # [right, left]
+ # Arms half extended
+ return np.array(
+ [0.403, -0.636, 0.114, 1.432, 0.735, 1.205, -0.269, -0.403, -0.636, -0.114, 1.432, -0.735, 1.205, 0.269]
+ )
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.29, 0, 0),
+ "table": lambda table_length: (-0.26 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "bimanual"
+
+ @property
+ def _eef_name(self):
+ """
+ Since this is bimanual robot, returns dict with `'right'`, `'left'` keywords corresponding to their respective
+ values
+
+ Returns:
+ dict: Dictionary containing arm-specific eef names
+ """
+ return {"right": "right_hand", "left": "left_hand"}
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/iiwa_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/iiwa_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeccb10f79ecbf5eafd7347538cfc8900b2a0a80
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/iiwa_robot.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class IIWA(ManipulatorModel):
+ """
+ IIWA is a bright and spunky robot created by KUKA
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/iiwa/robot.xml"), idn=idn)
+
+ @property
+ def default_mount(self):
+ return "RethinkMount"
+
+ @property
+ def default_gripper(self):
+ return "Robotiq140Gripper"
+
+ @property
+ def default_controller_config(self):
+ return "default_iiwa"
+
+ @property
+ def init_qpos(self):
+ return np.array([0.000, 0.650, 0.000, -1.890, 0.000, 0.600, 0.000])
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.6, 0, 0),
+ "table": lambda table_length: (-0.16 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "single"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/jaco_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/jaco_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..092d431f34ccc810c4f27400629edb787b70a8a8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/jaco_robot.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Jaco(ManipulatorModel):
+ """
+ Jaco is a kind and assistive robot created by Kinova
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/jaco/robot.xml"), idn=idn)
+
+ @property
+ def default_mount(self):
+ return "RethinkMount"
+
+ @property
+ def default_gripper(self):
+ return "JacoThreeFingerGripper"
+
+ @property
+ def default_controller_config(self):
+ return "default_jaco"
+
+ @property
+ def init_qpos(self):
+ return np.array([3.192, 3.680, -0.000, 1.170, 0.050, 3.760, 3.142])
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.6, 0, 0),
+ "table": lambda table_length: (-0.16 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "single"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/kinova3_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/kinova3_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a3835c78c39dc3d99955e5abdaa023e1ed3430
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/kinova3_robot.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Kinova3(ManipulatorModel):
+ """
+ The Gen3 robot is the sparkly newest addition to the Kinova line
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/kinova3/robot.xml"), idn=idn)
+
+ @property
+ def default_mount(self):
+ return "PhantomMount"
+
+ @property
+ def default_gripper(self):
+ return "Robotiq85Gripper"
+
+ @property
+ def default_controller_config(self):
+ return "default_kinova3"
+
+ @property
+ def init_qpos(self):
+ return np.array([0.000, 0.650, 0.000, 1.890, 0.000, 0.600, -np.pi / 2])
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.6, 0, 0),
+ "table": lambda table_length: (-0.16 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "single"
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/manipulator_model.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/manipulator_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7197d69f15212b4fd02ebffca71fb98b7fda617a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/manipulator_model.py
@@ -0,0 +1,177 @@
+from collections import OrderedDict
+
+import numpy as np
+
+from robosuite.models.robots import RobotModel
+from robosuite.utils.mjcf_utils import find_elements, string_to_array
+
+
+class ManipulatorModel(RobotModel):
+ """
+ Base class for all manipulator models (robot arm(s) with gripper(s)).
+
+ Args:
+ fname (str): Path to relevant xml file from which to create this robot instance
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, fname, idn=0):
+ # Always run super init first
+ super().__init__(fname, idn=idn)
+
+ # key: gripper name and value: gripper model
+ self.grippers = OrderedDict()
+
+ # Grab hand's offset from final robot link (string -> np.array -> elements [1, 2, 3, 0] (x, y, z, w))
+ # Different case based on whether we're dealing with single or bimanual armed robot
+ if self.arm_type == "single":
+ hand_element = find_elements(
+ root=self.root, tags="body", attribs={"name": self.eef_name}, return_first=True
+ )
+ self.hand_rotation_offset = string_to_array(hand_element.get("quat", "1 0 0 0"))[[1, 2, 3, 0]]
+ else: # "bimanual" case
+ self.hand_rotation_offset = {}
+ for arm in ("right", "left"):
+ hand_element = find_elements(
+ root=self.root, tags="body", attribs={"name": self.eef_name[arm]}, return_first=True
+ )
+ self.hand_rotation_offset[arm] = string_to_array(hand_element.get("quat", "1 0 0 0"))[[1, 2, 3, 0]]
+
+ # Get camera names for this robot
+ self.cameras = self.get_element_names(self.worldbody, "camera")
+
+ def add_gripper(self, gripper, arm_name=None):
+ """
+ Mounts @gripper to arm.
+
+ Throws error if robot already has a gripper or gripper type is incorrect.
+
+ Args:
+ gripper (GripperModel): gripper MJCF model
+ arm_name (str): name of arm mount -- defaults to self.eef_name if not specified
+
+ Raises:
+ ValueError: [Multiple grippers]
+ """
+ if arm_name is None:
+ arm_name = self.eef_name
+ if arm_name in self.grippers:
+ raise ValueError("Attempts to add multiple grippers to one body")
+
+ self.merge(gripper, merge_body=arm_name)
+
+ self.grippers[arm_name] = gripper
+
+ # Update cameras in this model
+ self.cameras = self.get_element_names(self.worldbody, "camera")
+
+ # -------------------------------------------------------------------------------------- #
+ # Public Properties: In general, these are the name-adjusted versions from the private #
+ # attributes pulled from their respective raw xml files #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def eef_name(self):
+ """
+ Returns:
+ str or dict of str: Prefix-adjusted eef name for this robot. If bimanual robot, returns {"left", "right"}
+ keyword-mapped eef names
+ """
+ return self.correct_naming(self._eef_name)
+
+ @property
+ def models(self):
+ """
+ Returns a list of all m(sub-)models owned by this robot model. By default, this includes the gripper model,
+ if specified
+
+ Returns:
+ list: models owned by this object
+ """
+ models = super().models
+ return models + list(self.grippers.values())
+
+ # -------------------------------------------------------------------------------------- #
+ # -------------------------- Private Properties ---------------------------------------- #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def _important_sites(self):
+ """
+ Returns:
+ dict: (Default is no important sites; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def _eef_name(self):
+ """
+ XML eef name for this robot to which grippers can be attached. Note that these should be the raw
+ string names directly pulled from a robot's corresponding XML file, NOT the adjusted name with an
+ auto-generated naming prefix
+
+ Returns:
+ str: Raw XML eef name for this robot (default is "right_hand")
+ """
+ return "right_hand"
+
+ # -------------------------------------------------------------------------------------- #
+ # All subclasses must implement the following properties #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def default_gripper(self):
+ """
+ Defines the default gripper type for this robot that gets added to end effector
+
+ Returns:
+ str: Default gripper name to add to this robot
+ """
+ raise NotImplementedError
+
+ @property
+ def arm_type(self):
+ """
+ Type of robot arm. Should be either "bimanual" or "single" (or something else if it gets added in the future)
+
+ Returns:
+ str: Type of robot
+ """
+ raise NotImplementedError
+
+ @property
+ def base_xpos_offset(self):
+ """
+ Defines the dict of various (x,y,z) tuple offsets relative to specific arenas placed at (0,0,0)
+ Assumes robot is facing forwards (in the +x direction) when determining offset. Should have entries for each
+ manipulator arena case; i.e.: "bins", "empty", and "table")
+
+ Returns:
+ dict:
+
+ :`'bins'`: (x,y,z) robot offset if placed in bins arena
+ :`'empty'`: (x,y,z) robot offset if placed in the empty arena
+ :`'table'`: lambda function that takes in table_length and returns corresponding (x,y,z) offset
+ if placed in the table arena
+ """
+ raise NotImplementedError
+
+ @property
+ def top_offset(self):
+ raise NotImplementedError
+
+ @property
+ def _horizontal_radius(self):
+ raise NotImplementedError
+
+ @property
+ def default_mount(self):
+ raise NotImplementedError
+
+ @property
+ def default_controller_config(self):
+ raise NotImplementedError
+
+ @property
+ def init_qpos(self):
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/panda_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/panda_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..336440da8de6d33b52089a1ffb8f58f3e5e17db8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/panda_robot.py
@@ -0,0 +1,55 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Panda(ManipulatorModel):
+ """
+ Panda is a sensitive single-arm robot designed by Franka.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/panda/robot.xml"), idn=idn)
+
+ # Set joint damping
+ self.set_joint_attribute(attrib="damping", values=np.array((0.1, 0.1, 0.1, 0.1, 0.1, 0.01, 0.01)))
+
+ @property
+ def default_mount(self):
+ return "RethinkMount"
+
+ @property
+ def default_gripper(self):
+ return "PandaGripper"
+
+ @property
+ def default_controller_config(self):
+ return "default_panda"
+
+ @property
+ def init_qpos(self):
+ return np.array([0, np.pi / 16.0, 0.00, -np.pi / 2.0 - np.pi / 3.0, 0.00, np.pi - 0.2, np.pi / 4])
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.6, 0, 0),
+ "table": lambda table_length: (-0.16 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "single"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/sawyer_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/sawyer_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc9c4e763b09aa3e750f88e6493975edfe66b6ed
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/sawyer_robot.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class Sawyer(ManipulatorModel):
+ """
+ Sawyer is a witty single-arm robot designed by Rethink Robotics.
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/sawyer/robot.xml"), idn=idn)
+
+ @property
+ def default_mount(self):
+ return "RethinkMount"
+
+ @property
+ def default_gripper(self):
+ return "RethinkGripper"
+
+ @property
+ def default_controller_config(self):
+ return "default_sawyer"
+
+ @property
+ def init_qpos(self):
+ return np.array([0, -1.18, 0.00, 2.18, 0.00, 0.57, -1.57])
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.6, 0, 0),
+ "table": lambda table_length: (-0.16 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "single"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/ur5e_robot.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/ur5e_robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecd7a48b3c33fd94fd72187244907f93f74b1dca
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/manipulators/ur5e_robot.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from robosuite.models.robots.manipulators.manipulator_model import ManipulatorModel
+from robosuite.utils.mjcf_utils import xml_path_completion
+
+
+class UR5e(ManipulatorModel):
+ """
+ UR5e is a sleek and elegant new robot created by Universal Robots
+
+ Args:
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, idn=0):
+ super().__init__(xml_path_completion("robots/ur5e/robot.xml"), idn=idn)
+
+ @property
+ def default_mount(self):
+ return "RethinkMount"
+
+ @property
+ def default_gripper(self):
+ return "Robotiq85Gripper"
+
+ @property
+ def default_controller_config(self):
+ return "default_ur5e"
+
+ @property
+ def init_qpos(self):
+ return np.array([-0.470, -1.735, 2.480, -2.275, -1.590, -1.991])
+
+ @property
+ def base_xpos_offset(self):
+ return {
+ "bins": (-0.5, -0.1, 0),
+ "empty": (-0.6, 0, 0),
+ "table": lambda table_length: (-0.16 - table_length / 2, 0, 0),
+ }
+
+ @property
+ def top_offset(self):
+ return np.array((0, 0, 1.0))
+
+ @property
+ def _horizontal_radius(self):
+ return 0.5
+
+ @property
+ def arm_type(self):
+ return "single"
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/robots/robot_model.py b/phantom/submodules/phantom-robosuite/robosuite/models/robots/robot_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..63758c2094dd7dd97c6d551e708cfc0bab9d9b98
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/robots/robot_model.py
@@ -0,0 +1,298 @@
+import numpy as np
+
+from robosuite.models.base import MujocoXMLModel
+from robosuite.utils.mjcf_utils import ROBOT_COLLISION_COLOR, array_to_string, string_to_array
+from robosuite.utils.transform_utils import euler2mat, mat2quat
+
+REGISTERED_ROBOTS = {}
+
+
+def register_robot(target_class):
+ REGISTERED_ROBOTS[target_class.__name__] = target_class
+
+
+def create_robot(robot_name, *args, **kwargs):
+ """
+ Instantiates a Robot object.
+
+ Args:
+ robot_name (str): Name of the robot to initialize
+ *args: Additional arguments to pass to the specific Robot class initializer
+ **kwargs: Additional arguments to pass to the specific Robot class initializer
+
+ Returns:
+ Robot: Desired robot
+
+ Raises:
+ Exception: [Invalid robot name]
+ """
+ if robot_name not in REGISTERED_ROBOTS:
+ raise Exception(
+ "Robot {} not found. Make sure it is a registered robot among: {}".format(
+ robot_name, ", ".join(REGISTERED_ROBOTS)
+ )
+ )
+ return REGISTERED_ROBOTS[robot_name](*args, **kwargs)
+
+
+class RobotModelMeta(type):
+ """Metaclass for registering robot arms"""
+
+ def __new__(meta, name, bases, class_dict):
+ cls = super().__new__(meta, name, bases, class_dict)
+
+ # List all environments that should not be registered here.
+ _unregistered_envs = ["RobotModel", "ManipulatorModel"]
+
+ if cls.__name__ not in _unregistered_envs:
+ register_robot(cls)
+ return cls
+
+
+class RobotModel(MujocoXMLModel, metaclass=RobotModelMeta):
+ """
+ Base class for all robot models.
+
+ Args:
+ fname (str): Path to relevant xml file from which to create this robot instance
+ idn (int or str): Number or some other unique identification string for this robot instance
+ """
+
+ def __init__(self, fname, idn=0):
+ super().__init__(fname, idn=idn)
+
+ # Define other variables that get filled later
+ self.mount = None
+
+ # Get camera names for this robot
+ self.cameras = self.get_element_names(self.worldbody, "camera")
+
+ # By default, set small frictionloss and armature values
+ self.set_joint_attribute(attrib="frictionloss", values=0.1 * np.ones(self.dof), force=False)
+ self.set_joint_attribute(attrib="damping", values=0.1 * np.ones(self.dof), force=False)
+ self.set_joint_attribute(
+ attrib="armature", values=np.array([5.0 / (i + 1) for i in range(self.dof)]), force=False
+ )
+
+ def set_base_xpos(self, pos):
+ """
+ Places the robot on position @pos.
+
+ Args:
+ pos (3-array): (x,y,z) position to place robot base
+ """
+ self._elements["root_body"].set("pos", array_to_string(pos - self.bottom_offset))
+
+ def set_base_ori(self, rot):
+ """
+ Rotates robot by rotation @rot from its original orientation.
+
+ Args:
+ rot (3-array): (r,p,y) euler angles specifying the orientation for the robot base
+ """
+ # xml quat assumes w,x,y,z so we need to convert to this format from outputted x,y,z,w format from fcn
+ rot = mat2quat(euler2mat(rot))[[3, 0, 1, 2]]
+ self._elements["root_body"].set("quat", array_to_string(rot))
+
+ def set_joint_attribute(self, attrib, values, force=True):
+ """
+ Sets joint attributes, e.g.: friction loss, damping, etc.
+
+ Args:
+ attrib (str): Attribute to set for all joints
+ values (n-array): Values to set for each joint
+ force (bool): If True, will automatically override any pre-existing value. Otherwise, if a value already
+ exists for this value, it will be skipped
+
+ Raises:
+ AssertionError: [Inconsistent dimension sizes]
+ """
+ assert values.size == len(self._elements["joints"]), (
+ "Error setting joint attributes: "
+ + "Values must be same size as joint dimension. Got {}, expected {}!".format(values.size, self.dof)
+ )
+ for i, joint in enumerate(self._elements["joints"]):
+ if force or joint.get(attrib, None) is None:
+ joint.set(attrib, array_to_string(np.array([values[i]])))
+
+ def add_mount(self, mount):
+ """
+ Mounts @mount to arm.
+
+ Throws error if robot already has a mount or if mount type is incorrect.
+
+ Args:
+ mount (MountModel): mount MJCF model
+
+ Raises:
+ ValueError: [mount already added]
+ """
+ if self.mount is not None:
+ raise ValueError("Mount already added for this robot!")
+
+ # First adjust mount's base position
+ offset = self.base_offset - mount.top_offset
+ mount._elements["root_body"].set("pos", array_to_string(offset))
+
+ self.merge(mount, merge_body=self.root_body)
+
+ self.mount = mount
+
+ # Update cameras in this model
+ self.cameras = self.get_element_names(self.worldbody, "camera")
+
+ # -------------------------------------------------------------------------------------- #
+ # Public Properties: In general, these are the name-adjusted versions from the private #
+ # attributes pulled from their respective raw xml files #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def naming_prefix(self):
+ return "robot{}_".format(self.idn)
+
+ @property
+ def dof(self):
+ """
+ Defines the number of DOF of the robot
+
+ Returns:
+ int: robot DOF
+ """
+ return len(self._joints)
+
+ @property
+ def bottom_offset(self):
+ """
+ Returns vector from model root body to model bottom.
+ By default, this is equivalent to this robot's mount's (bottom_offset - top_offset) + this robot's base offset
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ return (
+ (self.mount.bottom_offset - self.mount.top_offset) + self._base_offset
+ if self.mount is not None
+ else self._base_offset
+ )
+
+ @property
+ def horizontal_radius(self):
+ """
+ Returns maximum distance from model root body to any radial point of the model. This method takes into
+ account the mount horizontal radius as well
+
+ Returns:
+ float: radius
+ """
+ return max(self._horizontal_radius, self.mount.horizontal_radius)
+
+ @property
+ def models(self):
+ """
+ Returns a list of all m(sub-)models owned by this robot model. By default, this includes the mount model,
+ if specified
+
+ Returns:
+ list: models owned by this object
+ """
+ return [self.mount] if self.mount is not None else []
+
+ @property
+ def contact_geom_rgba(self):
+ return ROBOT_COLLISION_COLOR
+
+ # -------------------------------------------------------------------------------------- #
+ # All subclasses must implement the following properties #
+ # -------------------------------------------------------------------------------------- #
+
+ @property
+ def default_mount(self):
+ """
+ Defines the default mount type for this robot that gets added to root body (base)
+
+ Returns:
+ str: Default mount name to add to this robot
+ """
+ raise NotImplementedError
+
+ @property
+ def default_controller_config(self):
+ """
+ Defines the name of default controller config file in the controllers/config directory for this robot.
+
+ Returns:
+ str: filename of default controller config for this robot
+ """
+ raise NotImplementedError
+
+ @property
+ def init_qpos(self):
+ """
+ Defines the default rest qpos of this robot
+
+ Returns:
+ np.array: Default init qpos of this robot
+ """
+ raise NotImplementedError
+
+ @property
+ def base_xpos_offset(self):
+ """
+ Defines the dict of various (x,y,z) tuple offsets relative to specific arenas placed at (0,0,0)
+ Assumes robot is facing forwards (in the +x direction) when determining offset. Should have entries for each
+ arena case; i.e.: "bins", "empty", and "table")
+
+ Returns:
+ dict: Dict mapping arena names to robot offsets from the global origin (dict entries may also be lambdas
+ for variable offsets)
+ """
+ raise NotImplementedError
+
+ @property
+ def top_offset(self):
+ """
+ Returns vector from model root body to model top.
+ Useful for, e.g. placing models on a surface.
+ Must be defined by subclass.
+
+ Returns:
+ np.array: (dx, dy, dz) offset vector
+ """
+ raise NotImplementedError
+
+ @property
+ def _horizontal_radius(self):
+ """
+ Returns maximum distance from model root body to any radial point of the model.
+
+ Helps us put models programmatically without them flying away due to a huge initial contact force.
+ Must be defined by subclass.
+
+ Returns:
+ float: radius
+ """
+ raise NotImplementedError
+
+ @property
+ def _important_sites(self):
+ """
+ Returns:
+ dict: (Default is no important sites; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def _important_geoms(self):
+ """
+ Returns:
+ dict: (Default is no important geoms; i.e.: empty dict)
+ """
+ return {}
+
+ @property
+ def _important_sensors(self):
+ """
+ Returns:
+ dict: (Default is no sensors; i.e.: empty dict)
+ """
+ return {}
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/tasks/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/models/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f92c679f3dc810b4870bcc4c30b185299c2da5fd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/tasks/__init__.py
@@ -0,0 +1,2 @@
+from .task import Task
+from .manipulation_task import ManipulationTask
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/tasks/manipulation_task.py b/phantom/submodules/phantom-robosuite/robosuite/models/tasks/manipulation_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..8309e2420d1145be40b9c3872743f2939d6b061a
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/tasks/manipulation_task.py
@@ -0,0 +1,7 @@
+from robosuite.models.tasks.task import Task
+
+
+class ManipulationTask(Task):
+ """
+ A manipulation-specific task. This is currently a future-proofing placeholder.
+ """
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/tasks/task.py b/phantom/submodules/phantom-robosuite/robosuite/models/tasks/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..658fc2f3fd463d8a3bde4ffd1f7496c3070ab7bc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/tasks/task.py
@@ -0,0 +1,191 @@
+from copy import deepcopy
+
+from robosuite.models.objects import MujocoObject
+from robosuite.models.robots import RobotModel
+from robosuite.models.world import MujocoWorldBase
+from robosuite.utils.mjcf_utils import get_ids
+
+
+class Task(MujocoWorldBase):
+ """
+ Creates MJCF model for a task performed.
+
+ A task consists of one or more robots interacting with a variable number of
+ objects. This class combines the robot(s), the arena, and the objects
+ into a single MJCF model.
+
+ Args:
+ mujoco_arena (Arena): MJCF model of robot workspace
+
+ mujoco_robots (RobotModel or list of RobotModel): MJCF model of robot model(s) (list)
+
+ mujoco_objects (None or MujocoObject or list of MujocoObject): a list of MJCF models of physical objects
+
+ Raises:
+ AssertionError: [Invalid input object type]
+ """
+
+ def __init__(
+ self,
+ mujoco_arena,
+ mujoco_robots,
+ mujoco_objects=None,
+ ):
+ super().__init__()
+
+ # Store references to all models
+ self.mujoco_arena = mujoco_arena
+ self.mujoco_robots = [mujoco_robots] if isinstance(mujoco_robots, RobotModel) else mujoco_robots
+ if mujoco_objects is None:
+ self.mujoco_objects = []
+ else:
+ self.mujoco_objects = [mujoco_objects] if isinstance(mujoco_objects, MujocoObject) else mujoco_objects
+
+ # Merge all models
+ self.merge_arena(self.mujoco_arena)
+ for mujoco_robot in self.mujoco_robots:
+ self.merge_robot(mujoco_robot)
+ self.merge_objects(self.mujoco_objects)
+
+ self._instances_to_ids = None
+ self._geom_ids_to_instances = None
+ self._site_ids_to_instances = None
+ self._classes_to_ids = None
+ self._geom_ids_to_classes = None
+ self._site_ids_to_classes = None
+
+ def merge_robot(self, mujoco_robot):
+ """
+ Adds robot model to the MJCF model.
+
+ Args:
+ mujoco_robot (RobotModel): robot to merge into this MJCF model
+ """
+ self.merge(mujoco_robot)
+
+ def merge_arena(self, mujoco_arena):
+ """
+ Adds arena model to the MJCF model.
+
+ Args:
+ mujoco_arena (Arena): arena to merge into this MJCF model
+ """
+ self.merge(mujoco_arena)
+
+ def merge_objects(self, mujoco_objects):
+ """
+ Adds object models to the MJCF model.
+
+ Args:
+ mujoco_objects (list of MujocoObject): objects to merge into this MJCF model
+ """
+ for mujoco_obj in mujoco_objects:
+ # Make sure we actually got a MujocoObject
+ assert isinstance(mujoco_obj, MujocoObject), "Tried to merge non-MujocoObject! Got type: {}".format(
+ type(mujoco_obj)
+ )
+ # Merge this object
+ self.merge_assets(mujoco_obj)
+ self.worldbody.append(mujoco_obj.get_obj())
+
+ def generate_id_mappings(self, sim):
+ """
+ Generates IDs mapping class instances to set of (visual) geom IDs corresponding to that class instance
+
+ Args:
+ sim (MjSim): Current active mujoco simulation object
+ """
+ self._instances_to_ids = {}
+ self._geom_ids_to_instances = {}
+ self._site_ids_to_instances = {}
+ self._classes_to_ids = {}
+ self._geom_ids_to_classes = {}
+ self._site_ids_to_classes = {}
+
+ models = [model for model in self.mujoco_objects]
+ for robot in self.mujoco_robots:
+ models += [robot] + robot.models
+
+ # Parse all mujoco models from robots and objects
+ for model in models:
+ # Grab model class name and visual IDs
+ cls = str(type(model)).split("'")[1].split(".")[-1]
+ inst = model.name
+ id_groups = [
+ get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"),
+ get_ids(sim=sim, elements=model.sites, element_type="site"),
+ ]
+ group_types = ("geom", "site")
+ ids_to_instances = (self._geom_ids_to_instances, self._site_ids_to_instances)
+ ids_to_classes = (self._geom_ids_to_classes, self._site_ids_to_classes)
+
+ # Add entry to mapping dicts
+
+ # Instances should be unique
+ assert inst not in self._instances_to_ids, f"Instance {inst} already registered; should be unique"
+ self._instances_to_ids[inst] = {}
+
+ # Classes may not be unique
+ if cls not in self._classes_to_ids:
+ self._classes_to_ids[cls] = {group_type: [] for group_type in group_types}
+
+ for ids, group_type, ids_to_inst, ids_to_cls in zip(
+ id_groups, group_types, ids_to_instances, ids_to_classes
+ ):
+ # Add geom, site ids
+ self._instances_to_ids[inst][group_type] = ids
+ self._classes_to_ids[cls][group_type] += ids
+
+ # Add reverse mappings as well
+ for idn in ids:
+ assert idn not in ids_to_inst, f"ID {idn} already registered; should be unique"
+ ids_to_inst[idn] = inst
+ ids_to_cls[idn] = cls
+
+ @property
+ def geom_ids_to_instances(self):
+ """
+ Returns:
+ dict: Mapping from geom IDs in sim to specific class instance names
+ """
+ return deepcopy(self._geom_ids_to_instances)
+
+ @property
+ def site_ids_to_instances(self):
+ """
+ Returns:
+ dict: Mapping from site IDs in sim to specific class instance names
+ """
+ return deepcopy(self._site_ids_to_instances)
+
+ @property
+ def instances_to_ids(self):
+ """
+ Returns:
+ dict: Mapping from specific class instance names to {geom, site} IDs in sim
+ """
+ return deepcopy(self._instances_to_ids)
+
+ @property
+ def geom_ids_to_classes(self):
+ """
+ Returns:
+ dict: Mapping from geom IDs in sim to specific classes
+ """
+ return deepcopy(self._geom_ids_to_classes)
+
+ @property
+ def site_ids_to_classes(self):
+ """
+ Returns:
+ dict: Mapping from site IDs in sim to specific classes
+ """
+ return deepcopy(self._site_ids_to_classes)
+
+ @property
+ def classes_to_ids(self):
+ """
+ Returns:
+ dict: Mapping from specific classes to {geom, site} IDs in sim
+ """
+ return deepcopy(self._classes_to_ids)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/models/world.py b/phantom/submodules/phantom-robosuite/robosuite/models/world.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2f265d374dd0e9a7e2586176c0758a91f015dc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/models/world.py
@@ -0,0 +1,13 @@
+import robosuite.macros as macros
+from robosuite.models.base import MujocoXML
+from robosuite.utils.mjcf_utils import convert_to_string, find_elements, xml_path_completion
+
+
+class MujocoWorldBase(MujocoXML):
+ """Base class to inherit all mujoco worlds from."""
+
+ def __init__(self):
+ super().__init__(xml_path_completion("base.xml"))
+ # Modify the simulation timestep to be the requested value
+ options = find_elements(root=self.root, tags="option", attribs=None, return_first=True)
+ options.set("timestep", convert_to_string(macros.SIMULATION_TIMESTEP))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac7004e9d49e6122425170c97027b242029305c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/__init__.py
@@ -0,0 +1 @@
+from .base import load_renderer_config
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/base.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb60c3dde1523b6f89352614027c4e66cecf696
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/base.py
@@ -0,0 +1,78 @@
+"""
+This file contains the base renderer class for Mujoco environments.
+"""
+
+import abc
+import json
+import os
+
+
+def load_renderer_config(renderer):
+ """Loads the config of the specified renderer.
+ Modify the dictionary returned by this function
+ according to reuirements.
+
+ Args:
+ renderer (str): Name of the renderer to use.
+
+ Returns:
+ dict: renderer default config.
+ """
+ if renderer == "nvisii":
+ fname = "config/nvisii_config.json"
+ else:
+ raise ValueError(f"renderer type can only be 'nvisii' got '{renderer}'")
+
+ dir_path = os.path.dirname(__file__)
+ with open(os.path.join(dir_path, fname)) as f:
+ config = json.load(f)
+
+ return config
+
+
+class Renderer:
+ """
+ Base class for all robosuite renderers
+ Defines basic interface for all renderers to adhere to
+ """
+
+ def __init__(self, env, renderer_type="mujoco"):
+ self.env = env
+ self.renderer_type = renderer_type
+
+ def __str__(self):
+ """Prints the renderer type in a formatted way
+
+ Returns:
+ str: string representing the renderer
+ """
+ return f''
+
+ @abc.abstractmethod
+ def render(self, **kwargs):
+ """Renders the current state with the specified renderer"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def update(self):
+ """Updates the states in the renderer (for NVISII)"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def close(self):
+ """Closes the renderer objects"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def reset(self):
+ """Reset the renderer with initial states for environment"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def get_pixel_obs(self):
+ """Get the pixel observations from the given renderer
+
+ Returns:
+ numpyarr: numpy array representing pixels of renderer
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/base_parser.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/base_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0ff7aed6f533b60a3aa2f8914cb3fa5db1bf04
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/base_parser.py
@@ -0,0 +1,55 @@
+import abc
+import xml.etree.ElementTree as ET
+
+
+class BaseParser(object):
+ """
+ Base class for Parser objects used by renderers.
+ """
+
+ def __init__(self, renderer, env):
+ """
+ Parse the mujoco xml and initialize iG renderer objects.
+
+ Args:
+ renderer: the renderer
+ env : Mujoco env
+ """
+
+ self.renderer = renderer
+ self.env = env
+ self.xml_root = ET.fromstring(self.env.sim.model.get_xml())
+ self.parent_map = {c: p for p in self.xml_root.iter() for c in p}
+ self.visual_objects = {}
+
+ @abc.abstractmethod
+ def parse_textures(self):
+ """
+ Parse and load all textures and store them
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def parse_materials(self):
+ """
+ Parse all materials and use texture mapping to initialize materials
+ """
+ raise NotImplementedError
+
+ def parse_cameras(self):
+ """
+ Parse cameras and initialize the cameras.
+ """
+ raise NotImplementedError
+
+ def parse_meshes(self):
+ """
+ Create mapping of meshes.
+ """
+ raise NotImplementedError
+
+ def parse_geometries(self):
+ """
+ Iterate through each geometry and load it in the renderer.
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/config/nvisii_config.json b/phantom/submodules/phantom-robosuite/robosuite/renderers/config/nvisii_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5d6f5d862ada51716402a7d22b9780b92fdd6c7c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/config/nvisii_config.json
@@ -0,0 +1,14 @@
+{
+ "img_path": "images/",
+ "width": 1280,
+ "height": 720,
+ "spp": 512,
+ "use_noise": false,
+ "debug_mode": false,
+ "video_mode": false,
+ "video_path": "videos/",
+ "video_name": "robosuite_video_0.mp4",
+ "video_fps": 30,
+ "verbose": 1,
+ "vision_modalities": null
+}
\ No newline at end of file
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/context/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/context/egl_context.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/egl_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bfc15c18aca6e96ae4ad9f6fba0c45642899776
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/egl_context.py
@@ -0,0 +1,155 @@
+# Modifications Copyright 2022 The robosuite Authors
+# Original Copyright 2018 The dm_control Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import atexit
+import ctypes
+import os
+
+PYOPENGL_PLATFORM = os.environ.get("PYOPENGL_PLATFORM")
+
+if not PYOPENGL_PLATFORM:
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
+elif PYOPENGL_PLATFORM.lower() != "egl":
+ raise ImportError(
+ "Cannot use EGL rendering platform. "
+ "The PYOPENGL_PLATFORM environment variable is set to {!r} "
+ "(should be either unset or 'egl')."
+ )
+
+from mujoco.egl import egl_ext as EGL
+from OpenGL import error
+
+
+def create_initialized_egl_device_display(device_id=0):
+ """Creates an initialized EGL display directly on a device."""
+ all_devices = EGL.eglQueryDevicesEXT()
+ selected_device = (
+ os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if os.environ.get("MUJOCO_EGL_DEVICE_ID", None) is None
+ else os.environ.get("MUJOCO_EGL_DEVICE_ID", None)
+ )
+ if selected_device is None:
+ candidates = all_devices
+ if device_id == -1:
+ device_idx = 0
+ else:
+ device_idx = device_id
+ else:
+ if not selected_device.isdigit():
+ device_inds = [int(x) for x in selected_device.split(",")]
+ if device_id == -1:
+ device_idx = device_inds[0]
+ else:
+ assert device_id in device_inds, "specified device id is not made visible in environment variables."
+ device_idx = device_id
+ else:
+ device_idx = int(selected_device)
+ if not 0 <= device_idx < len(all_devices):
+ raise RuntimeError(
+ f"The MUJOCO_EGL_DEVICE_ID environment variable must be an integer "
+ f"between 0 and {len(all_devices)-1} (inclusive), got {device_idx}."
+ )
+ candidates = all_devices[device_idx : device_idx + 1]
+ for device in candidates:
+ display = EGL.eglGetPlatformDisplayEXT(EGL.EGL_PLATFORM_DEVICE_EXT, device, None)
+ if display != EGL.EGL_NO_DISPLAY and EGL.eglGetError() == EGL.EGL_SUCCESS:
+ # `eglInitialize` may or may not raise an exception on failure depending
+ # on how PyOpenGL is configured. We therefore catch a `GLError` and also
+ # manually check the output of `eglGetError()` here.
+ try:
+ initialized = EGL.eglInitialize(display, None, None)
+ except error.GLError:
+ pass
+ else:
+ if initialized == EGL.EGL_TRUE and EGL.eglGetError() == EGL.EGL_SUCCESS:
+ return display
+ return EGL.EGL_NO_DISPLAY
+
+
+global EGL_DISPLAY
+EGL_DISPLAY = None
+
+EGL_ATTRIBUTES = (
+ EGL.EGL_RED_SIZE,
+ 8,
+ EGL.EGL_GREEN_SIZE,
+ 8,
+ EGL.EGL_BLUE_SIZE,
+ 8,
+ EGL.EGL_ALPHA_SIZE,
+ 8,
+ EGL.EGL_DEPTH_SIZE,
+ 24,
+ EGL.EGL_STENCIL_SIZE,
+ 8,
+ EGL.EGL_COLOR_BUFFER_TYPE,
+ EGL.EGL_RGB_BUFFER,
+ EGL.EGL_SURFACE_TYPE,
+ EGL.EGL_PBUFFER_BIT,
+ EGL.EGL_RENDERABLE_TYPE,
+ EGL.EGL_OPENGL_BIT,
+ EGL.EGL_NONE,
+)
+
+
+class EGLGLContext:
+ """An EGL context for headless accelerated OpenGL rendering on GPU devices."""
+
+ def __init__(self, max_width, max_height, device_id=0):
+
+ del max_width, max_height # unused
+ num_configs = ctypes.c_long()
+ config_size = 1
+ config = EGL.EGLConfig()
+ EGL.eglReleaseThread()
+ global EGL_DISPLAY
+ if EGL_DISPLAY is None:
+ # only initialize for the first time
+ EGL_DISPLAY = create_initialized_egl_device_display(device_id=device_id)
+ if EGL_DISPLAY == EGL.EGL_NO_DISPLAY:
+ raise ImportError(
+ "Cannot initialize a EGL device display. This likely means that your EGL "
+ "driver does not support the PLATFORM_DEVICE extension, which is "
+ "required for creating a headless rendering context."
+ )
+ atexit.register(EGL.eglTerminate, EGL_DISPLAY)
+ EGL.eglChooseConfig(EGL_DISPLAY, EGL_ATTRIBUTES, ctypes.byref(config), config_size, num_configs)
+ if num_configs.value < 1:
+ raise RuntimeError(
+ "EGL failed to find a framebuffer configuration that matches the "
+ "desired attributes: {}".format(EGL_ATTRIBUTES)
+ )
+ EGL.eglBindAPI(EGL.EGL_OPENGL_API)
+ self._context = EGL.eglCreateContext(EGL_DISPLAY, config, EGL.EGL_NO_CONTEXT, None)
+ if not self._context:
+ raise RuntimeError("Cannot create an EGL context.")
+
+ def make_current(self):
+ if not EGL.eglMakeCurrent(EGL_DISPLAY, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, self._context):
+ raise RuntimeError("Failed to make the EGL context current.")
+
+ def free(self):
+ """Frees resources associated with this context."""
+ if self._context:
+ current_context = EGL.eglGetCurrentContext()
+ if current_context and self._context.address == current_context.address:
+ EGL.eglMakeCurrent(EGL_DISPLAY, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
+ EGL.eglDestroyContext(EGL_DISPLAY, self._context)
+ EGL.eglReleaseThread()
+ self._context = None
+
+ def __del__(self):
+ self.free()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/context/glfw_context.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/glfw_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..087fe0d496833c489bc4debb652a08228cc7b0bf
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/glfw_context.py
@@ -0,0 +1,24 @@
+# Copyright 2017 The dm_control Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An OpenGL context created via GLFW."""
+
+from mujoco.glfw import GLContext
+
+
+class GLFWGLContext(GLContext):
+ """An OpenGL context created via GLFW."""
+
+ def __init__(self, max_width, max_height, device_id=0):
+ super().__init__(max_width, max_height)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/context/osmesa_context.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/osmesa_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c8918cc270a7853cb5af7df15ac8ea00019031
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/context/osmesa_context.py
@@ -0,0 +1,26 @@
+# Copyright 2018 The dm_control Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An OSMesa context for software-based OpenGL rendering."""
+
+import os
+
+from mujoco.osmesa import GLContext
+
+
+class OSMesaGLContext(GLContext):
+ """An OSMesa context for software-based OpenGL rendering."""
+
+ def __init__(self, max_width, max_height, device_id=-1):
+ super().__init__(max_width, max_height)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/nvisii_renderer.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/nvisii_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..787eb7b8ef9a08bf5df09fa02c1880c38c4b3c96
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/nvisii_renderer.py
@@ -0,0 +1,575 @@
+import colorsys
+import os
+
+import cv2
+import matplotlib.cm as cm
+import numpy as np
+import nvisii
+import open3d as o3d
+
+import robosuite as suite
+import robosuite.renderers.nvisii.nvisii_utils as utils
+from robosuite.renderers.base import Renderer
+from robosuite.renderers.nvisii.parser import Parser
+from robosuite.utils import transform_utils as T
+from robosuite.utils.mjcf_utils import xml_path_completion
+from robosuite.utils.transform_utils import mat2quat
+from robosuite.wrappers import Wrapper
+
+np.set_printoptions(threshold=np.inf)
+
+
+class NVISIIRenderer(Renderer):
+ def __init__(
+ self,
+ env,
+ img_path="images/",
+ width=500,
+ height=500,
+ spp=256,
+ use_noise=False,
+ debug_mode=False,
+ video_mode=False,
+ video_path="videos/",
+ video_name="robosuite_video_0.mp4",
+ video_fps=60,
+ verbose=1,
+ vision_modalities=None,
+ ):
+ """
+ Initializes the nvisii wrapper. Wrapping any MuJoCo environment in this
+ wrapper will use the NVISII wrapper for rendering.
+
+ Args:
+ env (MujocoEnv instance): The environment to wrap.
+
+ img_path (string): Path to images.
+
+ width (int, optional): Width of the rendered image. Defaults to 500.
+
+ height (int, optional): Height of the rendered image. Defaults to 500.
+
+ spp (int, optional): Sample-per-pixel for each image. Larger spp will result
+ in higher quality images but will take more time to render
+ each image. Higher quality images typically use an spp of
+ around 512.
+
+ use_noise (bool, optional): Use noise or denoise. Deafults to false.
+
+ debug_mode (bool, optional): Use debug mode for nvisii. Deafults to false.
+
+ video_mode (bool, optional): By deafult, the NVISII wrapper saves the results as
+ images. If video_mode is set to true, a video is
+ produced and will be stored in the directory defined
+ by video_path. Defaults to false.
+
+ video_path (string, optional): Path to store the video. Required if video_mode is
+ set to true. Defaults to 'videos/'.
+
+ video_name (string, optional): Name for the file for the video. Defaults to
+ 'robosuite_video_0.mp4'.
+
+ video_fps (int, optional): Frames per second for video. Defaults to 60.
+
+ verbose (int, optional): If verbose is set to 1, the wrapper will print the image
+ number for each image rendered. If verbose is set to 0,
+ nothing will be printed. Defaults to 1.
+
+ vision_modalities (string, optional): Options to render image with different ground truths
+ for NVISII. Options include "normal", "texture_coordinates",
+ "position", "depth".
+ """
+
+ super().__init__(env, renderer_type="nvisii")
+
+ self.env = env
+ self.img_path = img_path
+ self.width = width
+ self.height = height
+ self.spp = spp
+ self.use_noise = use_noise
+
+ self.video_mode = video_mode
+ self.video_path = video_path
+ self.video_name = video_name
+ self.video_fps = video_fps
+
+ self.verbose = verbose
+ self.vision_modalities = vision_modalities
+
+ self.img_cntr = 0
+
+ env._setup_references()
+
+ # enable interactive mode when debugging
+ if debug_mode:
+ nvisii.initialize_interactive()
+ else:
+ nvisii.initialize(headless=True)
+
+ self.segmentation_type = self.env.camera_segmentations
+
+ # add denoiser to nvisii if not using noise
+ if not use_noise:
+ nvisii.configure_denoiser()
+ nvisii.enable_denoiser()
+ nvisii.configure_denoiser(True, True, False)
+
+ if not os.path.exists(img_path):
+ os.makedirs(img_path)
+
+ if video_mode:
+ if not os.path.exists(video_path):
+ os.makedirs(video_path)
+ self.video = cv2.VideoWriter(
+ video_path + video_name, cv2.VideoWriter_fourcc(*"MP4V"), video_fps, (self.width, self.height)
+ )
+ print(f"video mode enabled")
+
+ if vision_modalities is None and self.segmentation_type[0] == None:
+ nvisii.sample_pixel_area(x_sample_interval=(0.0, 1.0), y_sample_interval=(0.0, 1.0))
+ else:
+ nvisii.sample_pixel_area(x_sample_interval=(0.5, 0.5), y_sample_interval=(0.5, 0.5))
+
+ self._init_nvisii_components()
+
+ def _init_nvisii_components(self):
+ self._init_lighting()
+ self._init_floor(image="plywood-4k.jpg")
+ self._init_walls(image="plaster-wall-4k.jpg")
+ self._init_camera()
+
+ self._load()
+
+ def _init_lighting(self):
+ # Intiailizes the lighting
+ self.light_1 = nvisii.entity.create(
+ name="light",
+ mesh=nvisii.mesh.create_sphere("light"),
+ transform=nvisii.transform.create("light"),
+ )
+
+ self.light_1.set_light(nvisii.light.create("light"))
+
+ self.light_1.get_light().set_intensity(150) # intensity of the light
+ self.light_1.get_transform().set_scale(nvisii.vec3(0.3)) # scale the light down
+ self.light_1.get_transform().set_position(nvisii.vec3(3, 3, 4)) # sets the position of the light
+
+ def _init_floor(self, image):
+ """
+ Intiailizes the floor
+
+ Args:
+ image (string): String for the file to use as an image for the floor
+
+ """
+ floor_mesh = nvisii.mesh.create_plane(name="plane", size=nvisii.vec2(3, 3))
+
+ floor_entity = nvisii.entity.create(
+ name="floor",
+ mesh=floor_mesh,
+ material=nvisii.material.create("plane"),
+ transform=nvisii.transform.create("plane"),
+ )
+ floor_entity.get_transform().set_scale(nvisii.vec3(1))
+ floor_entity.get_transform().set_position(nvisii.vec3(0, 0, 0))
+
+ texture_image = xml_path_completion("textures/" + image)
+ texture = nvisii.texture.create_from_file(name="floor_texture", path=texture_image)
+
+ floor_entity.get_material().set_base_color_texture(texture)
+ floor_entity.get_material().set_roughness(0.4)
+ floor_entity.get_material().set_specular(0)
+
+ def _init_walls(self, image):
+ """
+ Intiailizes the walls
+
+ Args:
+ image (string): String for the file to use as an image for the walls
+ """
+ texture_image = xml_path_completion("textures/" + image)
+ texture = nvisii.texture.create_from_file(name="wall_texture", path=texture_image)
+
+ for wall in self.env.model.mujoco_arena.worldbody.findall("./geom[@material='walls_mat']"):
+
+ name = wall.get("name")
+ size = [float(x) for x in wall.get("size").split(" ")]
+
+ pos, quat = self._get_orientation_geom(name)
+
+ wall_entity = nvisii.entity.create(
+ name=name,
+ mesh=nvisii.mesh.create_box(name=name, size=nvisii.vec3(size[0], size[1], size[2])),
+ transform=nvisii.transform.create(name),
+ material=nvisii.material.create(name),
+ )
+
+ wall_entity.get_transform().set_position(nvisii.vec3(pos[0], pos[1], pos[2]))
+
+ wall_entity.get_transform().set_rotation(nvisii.quat(quat[0], quat[1], quat[2], quat[3]))
+
+ wall_entity.get_material().set_base_color_texture(texture)
+
+ def _init_camera(self):
+ """
+ Intializes the camera for the NVISII renderer
+ """
+
+ # intializes the camera
+ self.camera = nvisii.entity.create(
+ name="camera",
+ transform=nvisii.transform.create("camera_transform"),
+ )
+
+ self.camera.set_camera(
+ nvisii.camera.create_from_fov(
+ name="camera_camera", field_of_view=1, aspect=float(self.width) / float(self.height)
+ )
+ )
+
+ # Sets the primary camera of the renderer to the camera entity
+ nvisii.set_camera_entity(self.camera)
+ self._camera_configuration(
+ at_vec=nvisii.vec3(0, 0, 1.06),
+ up_vec=nvisii.vec3(0, 0, 1),
+ eye_vec=nvisii.vec3(1.24, 0.0, 1.35),
+ quat=nvisii.quat(-1, 0, 0, 0),
+ )
+
+ # Environment configuration
+ self._dome_light_intensity = 1
+ nvisii.set_dome_light_intensity(self._dome_light_intensity)
+ nvisii.set_max_bounce_depth(4)
+
+ def _camera_configuration(self, at_vec, up_vec, eye_vec, quat):
+ """
+ Sets the configuration for the NVISII camera. Configuration
+ is dependent on where the camera is located and where it
+ looks at
+ """
+ # configures the camera
+ self.camera.get_transform().look_at(
+ at=at_vec, up=up_vec, eye=eye_vec, previous=False # look at (world coordinate) # up vector
+ )
+
+ self.camera.get_transform().rotate_around(eye_vec, quat)
+
+ def set_camera_pos_quat(self, pos, quat):
+ self.camera.get_transform().set_position(pos)
+ self.camera.get_transform().look_at(
+ at=(0, 0, 1.06), up=(0, 0, 1), eye=pos, previous=False # look at (world coordinate) # up vector
+ )
+ # self.camera.get_transform().rotate_around(pos, quat)
+
+ def _get_orientation_geom(self, name):
+ """
+ Gets the position and quaternion for a geom
+ """
+
+ pos = self.env.sim.data.geom_xpos[self.env.sim.model.geom_name2id(name)]
+ R = self.env.sim.data.geom_xmat[self.env.sim.model.geom_name2id(name)].reshape(3, 3)
+
+ quat_xyzw = mat2quat(R)
+ quat = np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]])
+
+ return pos, quat
+
+ def _load(self):
+ """
+ Loads the nessecary textures, materials, and geoms into the
+ NVISII renderer
+ """
+ self.parser = Parser("nvisii", self.env, self.segmentation_type)
+ self.parser.parse_textures()
+ self.parser.parse_materials()
+ self.parser.parse_geometries()
+ self.components = self.parser.components
+ self.max_elements = self.parser.max_elements
+ self.max_instances = self.parser.max_instances
+ self.max_classes = self.parser.max_classes
+
+ def update(self):
+ """
+ Updates the states for the wrapper given a certain action
+
+ Args:
+ action (np-array): The action the robot should take
+ """
+ for key, value in self.components.items():
+ self._update_orientation(name=key, component=value)
+
+ def _update_orientation(self, name, component):
+ """
+ Update position for an object or a robot in renderer.
+
+ Args:
+ name (string): name of component
+ component (nvisii entity or scene): Object in renderer and other info
+ for object.
+ """
+
+ obj = component.obj
+ parent_body_name = component.parent_body_name
+ geom_pos = component.geom_pos
+ geom_quat = component.geom_quat
+ dynamic = component.dynamic
+
+ if not dynamic:
+ return
+
+ self.body_tags = ["robot", "pedestal", "gripper", "peg"]
+
+ if parent_body_name != "worldbody":
+ if self.tag_in_name(name):
+ pos = self.env.sim.data.get_body_xpos(parent_body_name)
+ else:
+ pos = self.env.sim.data.get_geom_xpos(name)
+
+ B = self.env.sim.data.body_xmat[self.env.sim.model.body_name2id(parent_body_name)].reshape((3, 3))
+ quat_xyzw_body = mat2quat(B)
+ quat_wxyz_body = np.array(
+ [quat_xyzw_body[3], quat_xyzw_body[0], quat_xyzw_body[1], quat_xyzw_body[2]]
+ ) # wxyz
+ nvisii_quat = nvisii.quat(*quat_wxyz_body) * nvisii.quat(*geom_quat)
+
+ if self.tag_in_name(name):
+ # Add position offset if there are position offset defined in the geom tag
+ homo_mat = T.pose2mat((np.zeros((1, 3), dtype=np.float32), quat_xyzw_body))
+ pos_offset = homo_mat @ np.array([geom_pos[0], geom_pos[1], geom_pos[2], 1.0]).transpose()
+ pos = pos + pos_offset[:3]
+
+ else:
+ pos = [0, 0, 0]
+ nvisii_quat = nvisii.quat(1, 0, 0, 0) # wxyz
+
+ if isinstance(obj, nvisii.scene):
+
+ # temp fix -- look into XML file for correct quat
+ if "s_visual" in name:
+ # single robot
+ if len(self.env.robots) == 1:
+ nvisii_quat = nvisii.quat(0, 0.5, 0, 0)
+ # two robots - 0
+ elif len(self.env.robots) == 2 and "robot_0" in name:
+ nvisii_quat = nvisii.quat(-0, 0.5, 0.5, 0)
+ # two robots - 1
+ else:
+ nvisii_quat = nvisii.quat(-0, 0.5, -0.5, 0)
+
+ obj.transforms[0].set_position(nvisii.vec3(pos[0], pos[1], pos[2]))
+ obj.transforms[0].set_rotation(nvisii_quat)
+ else:
+ obj.get_transform().set_position(nvisii.vec3(pos[0], pos[1], pos[2]))
+ obj.get_transform().set_rotation(nvisii_quat)
+
+ def tag_in_name(self, name):
+ """
+ Checks if one of the tags in body tags in the name
+
+ Args:
+ name (string): Name of component
+ """
+ for tag in self.body_tags:
+ if tag in name:
+ return True
+ return False
+
+ def render(self, render_type="png"):
+ """
+ Renders an image of the NVISII renderer
+
+ Args:
+ render_type (string, optional): Type of file to save as. Defaults to 'png'
+ """
+
+ self.img_cntr += 1
+ verbose_word = "frame" if self.video_mode else "image"
+
+ if self.video_mode:
+ img_file = f"{self.img_path}/image_0.{render_type}"
+ if self.segmentation_type[0] != None:
+ self.render_segmentation_data(img_file)
+ elif self.vision_modalities is None:
+ self.render_to_file(img_file)
+ else:
+ self.render_data_to_file(img_file)
+
+ self.video.write(cv2.imread(img_file))
+ else:
+ img_file = f"{self.img_path}/image_{self.img_cntr}.{render_type}"
+ if self.segmentation_type[0] != None:
+ self.render_segmentation_data(img_file)
+ elif self.vision_modalities is None:
+ self.render_to_file(img_file)
+ else:
+ self.render_data_to_file(img_file)
+
+ if self.verbose == 1:
+ print(f"Rendering {verbose_word}... {self.img_cntr}")
+
+ def render_to_file(self, img_file):
+ nvisii.render_to_file(width=self.width, height=self.height, samples_per_pixel=self.spp, file_path=img_file)
+
+ def render_segmentation_data(self, img_file):
+
+ segmentation_array = nvisii.render_data(
+ width=int(self.width),
+ height=int(self.height),
+ start_frame=0,
+ frame_count=1,
+ bounce=int(0),
+ options="entity_id",
+ seed=1,
+ )
+ segmentation_array = np.array(segmentation_array).reshape(self.height, self.width, 4)[:, :, 0]
+ segmentation_array[segmentation_array > 3.4028234663852886e37] = 0
+ segmentation_array[segmentation_array < 3.4028234663852886e-37] = 0
+ segmentation_array = np.flipud(segmentation_array)
+
+ rgb_data = self.segmentation_to_rgb(segmentation_array.astype(dtype=np.uint8))
+
+ from PIL import Image
+
+ rgb_img = Image.fromarray(rgb_data)
+ rgb_img.save(img_file)
+
+ def render_data_to_file(self, img_file):
+
+ if self.vision_modalities == "depth" and self.img_cntr != 1:
+
+ depth_data = nvisii.render_data(
+ width=self.width,
+ height=self.height,
+ start_frame=0,
+ frame_count=1,
+ bounce=int(0),
+ options=self.vision_modalities,
+ )
+
+ depth_data = np.array(depth_data).reshape(self.height, self.width, 4)
+ depth_data = np.flipud(depth_data)[:, :, [0, 1, 2]]
+
+ # normalize depths
+ depth_data[:, :, 0] = (depth_data[:, :, 0] - np.min(depth_data[:, :, 0])) / (
+ np.max(depth_data[:, :, 0]) - np.min(depth_data[:, :, 0])
+ )
+ depth_data[:, :, 1] = (depth_data[:, :, 1] - np.min(depth_data[:, :, 1])) / (
+ np.max(depth_data[:, :, 1]) - np.min(depth_data[:, :, 1])
+ )
+ depth_data[:, :, 2] = (depth_data[:, :, 2] - np.min(depth_data[:, :, 2])) / (
+ np.max(depth_data[:, :, 2]) - np.min(depth_data[:, :, 2])
+ )
+
+ from PIL import Image
+
+ depth_image = Image.fromarray(((1 - depth_data) * 255).astype(np.uint8))
+ depth_image.save(img_file)
+
+ elif self.vision_modalities == "normal" and self.img_cntr != 1:
+
+ normal_data = nvisii.render_data(
+ width=self.width,
+ height=self.height,
+ start_frame=0,
+ frame_count=1,
+ bounce=int(0),
+ options="screen_space_normal",
+ )
+
+ normal_data = np.array(normal_data).reshape(self.height, self.width, 4)
+ normal_data = np.flipud(normal_data)[:, :, [0, 1, 2]]
+
+ normal_data[:, :, 0] = (normal_data[:, :, 0] + 1) / 2 * 255 # R
+ normal_data[:, :, 1] = (normal_data[:, :, 1] + 1) / 2 * 255 # G
+ normal_data[:, :, 2] = 255 - ((normal_data[:, :, 2] + 1) / 2 * 255) # B
+
+ from PIL import Image
+
+ normal_image = Image.fromarray((normal_data).astype(np.uint8))
+ normal_image.save(img_file)
+
+ else:
+
+ nvisii.render_data_to_file(
+ width=self.width,
+ height=self.height,
+ start_frame=0,
+ frame_count=1,
+ bounce=int(0),
+ options=self.vision_modalities,
+ file_path=img_file,
+ )
+
+ def randomize_colors(self, N, bright=True):
+ """
+ Modified from https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/visualize.py#L59
+ Generate random colors.
+ To get visually distinct colors, generate them in HSV space then
+ convert to RGB.
+ """
+ brightness = 1.0 if bright else 0.5
+ hsv = [(1.0 * i / N, 1, brightness) for i in range(N)]
+ colors = np.array(list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)))
+ rstate = np.random.RandomState(seed=20)
+ np.random.shuffle(colors)
+ return colors
+
+ def segmentation_to_rgb(self, seg_im, random_colors=False):
+ """
+ Helper function to visualize segmentations as RGB frames.
+ NOTE: assumes that geom IDs go up to 255 at most - if not,
+ multiple geoms might be assigned to the same color.
+ """
+ # ensure all values lie within [0, 255]
+ seg_im = np.mod(seg_im, 256)
+
+ if random_colors:
+ colors = self.randomize_colors(N=256, bright=True)
+ return (255.0 * colors[seg_im]).astype(np.uint8)
+ else:
+
+ cmap = cm.get_cmap("jet")
+
+ max_r = 0
+ if self.segmentation_type[0][0] == "element":
+ max_r = np.amax(seg_im) + 1
+ elif self.segmentation_type[0][0] == "class":
+ max_r = self.max_classes
+ for i in range(len(seg_im)):
+ for j in range(len(seg_im[0])):
+ if seg_im[i][j] in self.parser.entity_id_class_mapping:
+ seg_im[i][j] = self.parser.entity_id_class_mapping[seg_im[i][j]]
+ else:
+ seg_im[i][j] = max_r - 1
+ elif self.segmentation_type[0][0] == "instance":
+ max_r = self.max_instances
+ for i in range(len(seg_im)):
+ for j in range(len(seg_im[0])):
+ if seg_im[i][j] in self.parser.entity_id_class_mapping:
+ seg_im[i][j] = self.parser.entity_id_class_mapping[seg_im[i][j]]
+ else:
+ seg_im[i][j] = max_r - 1
+
+ color_list = np.array([cmap(i / (max_r)) for i in range(max_r)])
+
+ return (color_list[seg_im] * 255).astype(np.uint8)
+
+ def reset(self):
+ nvisii.clear_all()
+ self._init_nvisii_components()
+ self.update()
+
+ def get_pixel_obs(self):
+ frame_buffer = nvisii.render(width=self.width, height=self.height, samples_per_pixel=self.spp)
+
+ frame_buffer = np.array(frame_buffer).reshape(self.height, self.width, 4)
+ frame_buffer = np.flipud(frame_buffer)
+
+ return frame_buffer
+
+ def close(self):
+ """
+ Deinitializes the nvisii rendering environment
+ """
+ nvisii.deinitialize()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/nvisii_utils.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/nvisii_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc62a727a962ff586820a8007eb719134b33560
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/nvisii_utils.py
@@ -0,0 +1,123 @@
+import math
+import os
+
+import numpy as np
+import nvisii
+
+
+def load_object(
+ geom,
+ geom_name,
+ geom_type,
+ geom_quat,
+ geom_pos,
+ geom_size,
+ geom_scale,
+ geom_rgba,
+ geom_tex_name,
+ geom_tex_file,
+ class_id,
+ meshes,
+):
+ """
+ Function that initializes the meshes in the memory.
+
+ Args:
+ geom (XML element): Object in XML file to load
+
+ geom_name (str): Name for the object.
+
+ geom_type (str): Type of the object. Types include "box", "cylinder", or "mesh".
+
+ geom_quat (array): Quaternion (wxyz) of the object.
+
+ geom_pos (array): Position of the object.
+
+ geom_size (array): Size of the object.
+
+ geom_scale (array): Scale of the object.
+
+ geom_rgba (array): Color of the object. This is only used if the geom type is not
+ a mesh and there is no specified material.
+
+ geom_tex_name (str): Name of the texture for the object
+
+ geom_tex_file (str): File of the texture for the object
+
+ class_id (int) : Class id for the component
+
+ meshes (dict): Meshes for the object
+ """
+
+ primitive_types = ["box", "cylinder"]
+ component = None
+
+ if geom_type == "box":
+
+ component = nvisii.entity.create(
+ name=geom_name,
+ mesh=nvisii.mesh.create_box(name=geom_name, size=nvisii.vec3(geom_size[0], geom_size[1], geom_size[2])),
+ transform=nvisii.transform.create(geom_name),
+ material=nvisii.material.create(geom_name),
+ )
+
+ elif geom_type == "cylinder":
+
+ component = nvisii.entity.create(
+ name=geom_name,
+ mesh=nvisii.mesh.create_capped_cylinder(name=geom_name, radius=geom_size[0], size=geom_size[1]),
+ transform=nvisii.transform.create(geom_name),
+ material=nvisii.material.create(geom_name),
+ )
+
+ elif geom_type == "sphere":
+
+ component = nvisii.entity.create(
+ name=geom_name,
+ mesh=nvisii.mesh.create_sphere(name=geom_name, radius=geom_size[0]),
+ transform=nvisii.transform.create(geom_name),
+ material=nvisii.material.create(geom_name),
+ )
+
+ elif geom_type == "mesh":
+ filename = meshes[geom.attrib["mesh"]]["file"]
+ filename = os.path.splitext(filename)[0] + ".obj"
+
+ component = nvisii.import_scene(
+ file_path=filename,
+ position=nvisii.vec3(geom_pos[0], geom_pos[1], geom_pos[2]),
+ scale=(geom_scale[0], geom_scale[1], geom_scale[2]),
+ rotation=nvisii.quat(geom_quat[0], geom_quat[1], geom_quat[2], geom_quat[3]),
+ )
+
+ entity_ids = []
+ if isinstance(component, nvisii.scene):
+ for i in range(len(component.entities)):
+ entity_ids.append(component.entities[i].get_id())
+ else:
+ entity_ids.append(component.get_id())
+
+ if geom_type in primitive_types:
+ component.get_transform().set_position(nvisii.vec3(float(geom_pos[0]), float(geom_pos[1]), float(geom_pos[2])))
+
+ if geom_tex_file is not None and geom_tex_name is not None and geom_type != "mesh":
+
+ texture = nvisii.texture.get(geom_tex_name)
+
+ if texture is None:
+ texture = nvisii.texture.create_from_file(name=geom_tex_name, path=geom_tex_file)
+
+ component.get_material().set_base_color_texture(texture)
+ else:
+ if "gripper" in geom_name:
+ if geom_rgba is not None:
+ if isinstance(component, nvisii.scene):
+ for entity in component.entities:
+ entity.get_material().set_base_color(nvisii.vec3(geom_rgba[0], geom_rgba[1], geom_rgba[2]))
+ else:
+ component.get_material().set_base_color(nvisii.vec3(geom_rgba[0], geom_rgba[1], geom_rgba[2]))
+ elif "hand_visual" in geom_name:
+ for entity in component.entities:
+ entity.get_material().set_base_color(nvisii.vec3(0.05, 0.05, 0.05))
+
+ return component, entity_ids
diff --git a/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/parser.py b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..b184cef1d29c7ab67408f5989b1a1cae67c874ab
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/renderers/nvisii/parser.py
@@ -0,0 +1,214 @@
+import xml.etree.ElementTree as ET
+from collections import namedtuple
+
+import numpy as np
+import nvisii
+
+from robosuite.renderers.base_parser import BaseParser
+from robosuite.renderers.nvisii.nvisii_utils import load_object
+from robosuite.utils.mjcf_utils import string_to_array
+
+Components = namedtuple(
+ "Components", ["obj", "geom_index", "element_id", "parent_body_name", "geom_pos", "geom_quat", "dynamic"]
+)
+
+
+class Parser(BaseParser):
+ def __init__(self, renderer, env, segmentation_type):
+ """
+ Parse the mujoco xml and initialize NVISII renderer objects.
+ Args:
+ env (Mujoco env): Environment to parse
+ """
+
+ super().__init__(renderer, env)
+ self.segmentation_type = segmentation_type
+ self.create_class_mapping()
+ self.components = {}
+
+ def parse_textures(self):
+ """
+ Parse and load all textures and store them
+ """
+
+ self.texture_attributes = {}
+ self.texture_id_mapping = {}
+
+ for texture in self.xml_root.iter("texture"):
+ texture_type = texture.get("type")
+ texture_name = texture.get("name")
+ texture_file = texture.get("file")
+ texture_rgb = texture.get("rgb1")
+
+ if texture_file is not None:
+ self.texture_attributes[texture_name] = texture.attrib
+ else:
+ color = np.array(string_to_array(texture_rgb))
+ self.texture_id_mapping[texture_name] = (color, texture_type)
+
+ def parse_materials(self):
+ """
+ Parse all materials and use texture mapping to initialize materials
+ """
+
+ self.material_texture_mapping = {}
+ for material in self.xml_root.iter("material"):
+ material_name = material.get("name")
+ texture_name = material.get("texture")
+ self.material_texture_mapping[material_name] = texture_name
+
+ def parse_meshes(self):
+ """
+ Create mapping of meshes.
+ """
+ self.meshes = {}
+ for mesh in self.xml_root.iter("mesh"):
+ self.meshes[mesh.get("name")] = mesh.attrib
+
+ def parse_geometries(self):
+ """
+ Iterate through each goemetry and load it in the NVISII renderer.
+ """
+ self.parse_meshes()
+ element_id = 0
+ repeated_names = {}
+ block_rendering_objects = ["VisualBread_g0", "VisualCan_g0", "VisualCereal_g0", "VisualMilk_g0"]
+
+ self.entity_id_class_mapping = {}
+
+ for geom_index, geom in enumerate(self.xml_root.iter("geom")):
+
+ parent_body = self.parent_map.get(geom)
+ parent_body_name = parent_body.get("name", "worldbody")
+
+ geom_name = geom.get("name")
+ geom_type = geom.get("type", "sphere")
+
+ rgba_str = geom.get("rgba")
+ geom_rgba = string_to_array(rgba_str) if rgba_str is not None else None
+
+ if geom_name is None:
+ if parent_body_name in repeated_names:
+ geom_name = parent_body_name + str(repeated_names[parent_body_name])
+ repeated_names[parent_body_name] += 1
+ else:
+ geom_name = parent_body_name + "0"
+ repeated_names[parent_body_name] = 1
+
+ if (geom.get("group") != "1" and geom_type != "plane") or ("collision" in geom_name):
+ continue
+
+ if "floor" in geom_name or "wall" in geom_name or geom_name in block_rendering_objects:
+ continue
+
+ geom_quat = string_to_array(geom.get("quat", "1 0 0 0"))
+ geom_quat = [geom_quat[0], geom_quat[1], geom_quat[2], geom_quat[3]]
+
+ # handling special case of bins arena
+ if "bin" in parent_body_name:
+ geom_pos = string_to_array(geom.get("pos", "0 0 0")) + string_to_array(parent_body.get("pos", "0 0 0"))
+ else:
+ geom_pos = string_to_array(geom.get("pos", "0 0 0"))
+
+ if geom_type == "mesh":
+ geom_scale = string_to_array(self.meshes[geom.get("mesh")].get("scale", "1 1 1"))
+ else:
+ geom_scale = [1, 1, 1]
+ geom_size = string_to_array(geom.get("size", "1 1 1"))
+
+ geom_mat = geom.get("material")
+
+ tags = ["bin"]
+ dynamic = True
+ if self.tag_in_name(geom_name, tags):
+ dynamic = False
+
+ geom_tex_name = None
+ geom_tex_file = None
+
+ if geom_mat is not None:
+ geom_tex_name = self.material_texture_mapping[geom_mat]
+
+ if geom_tex_name in self.texture_attributes:
+ geom_tex_file = self.texture_attributes[geom_tex_name]["file"]
+
+ class_id = self.get_class_id(geom_index, element_id)
+
+ # load obj into nvisii
+ obj, entity_ids = load_object(
+ geom=geom,
+ geom_name=geom_name,
+ geom_type=geom_type,
+ geom_quat=geom_quat,
+ geom_pos=geom_pos,
+ geom_size=geom_size,
+ geom_scale=geom_scale,
+ geom_rgba=geom_rgba,
+ geom_tex_name=geom_tex_name,
+ geom_tex_file=geom_tex_file,
+ class_id=class_id, # change
+ meshes=self.meshes,
+ )
+
+ element_id += 1
+
+ for entity_id in entity_ids:
+ self.entity_id_class_mapping[entity_id] = class_id
+
+ self.components[geom_name] = Components(
+ obj=obj,
+ geom_index=geom_index,
+ element_id=element_id,
+ parent_body_name=parent_body_name,
+ geom_pos=geom_pos,
+ geom_quat=geom_quat,
+ dynamic=dynamic,
+ )
+
+ self.max_elements = element_id
+
+ def create_class_mapping(self):
+ """
+ Create class name to index mapping for both semantic and instance
+ segmentation.
+ """
+ self.class2index = {}
+ for i, c in enumerate(self.env.model._classes_to_ids.keys()):
+ self.class2index[c] = i
+ self.class2index[None] = i + 1
+ self.max_classes = len(self.class2index)
+
+ self.instance2index = {}
+ for i, instance_class in enumerate(self.env.model._instances_to_ids.keys()):
+ self.instance2index[instance_class] = i
+ self.instance2index[None] = i + 1
+ self.max_instances = len(self.instance2index)
+
+ def get_class_id(self, geom_index, element_id):
+ """
+ Given index of the geom object get the class id based on
+ self.segmentation type.
+ """
+
+ if self.segmentation_type[0] == None or self.segmentation_type[0][0] == "element":
+ class_id = element_id
+ elif self.segmentation_type[0][0] == "class":
+ class_id = self.class2index[self.env.model._geom_ids_to_classes.get(geom_index)]
+ elif self.segmentation_type[0][0] == "instance":
+ class_id = self.instance2index[self.env.model._geom_ids_to_instances.get(geom_index)]
+
+ return class_id
+
+ def tag_in_name(self, name, tags):
+ """
+ Checks if one of the tags in body tags in the name
+
+ Args:
+ name (str): Name of geom element.
+
+ tags (array): List of keywords to check from.
+ """
+ for tag in tags:
+ if tag in name:
+ return True
+ return False
diff --git a/phantom/submodules/phantom-robosuite/robosuite/robots/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/robots/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6296c86ab3b693a248fc06b29acc149edf10c7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/robots/__init__.py
@@ -0,0 +1,20 @@
+from .manipulator import Manipulator
+from .single_arm import SingleArm
+from .bimanual import Bimanual
+
+from robosuite.models.robots.robot_model import REGISTERED_ROBOTS
+
+ALL_ROBOTS = REGISTERED_ROBOTS.keys()
+
+# Robot class mappings -- must be maintained manually
+ROBOT_CLASS_MAPPING = {
+ "Baxter": Bimanual,
+ "IIWA": SingleArm,
+ "Jaco": SingleArm,
+ "Kinova3": SingleArm,
+ "Panda": SingleArm,
+ "Sawyer": SingleArm,
+ "UR5e": SingleArm,
+}
+
+BIMANUAL_ROBOTS = {k.lower() for k, v in ROBOT_CLASS_MAPPING.items() if v == Bimanual}
diff --git a/phantom/submodules/phantom-robosuite/robosuite/robots/bimanual.py b/phantom/submodules/phantom-robosuite/robosuite/robots/bimanual.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad8cbdce4032878e1fff1f957e4730df56f7c6df
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/robots/bimanual.py
@@ -0,0 +1,623 @@
+import copy
+import os
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.controllers import controller_factory, load_controller_config
+from robosuite.models.grippers import gripper_factory
+from robosuite.robots.manipulator import Manipulator
+from robosuite.utils.buffers import DeltaBuffer, RingBuffer
+from robosuite.utils.observables import Observable, sensor
+
+
+class Bimanual(Manipulator):
+ """
+ Initializes a bimanual robot simulation object.
+
+ Args:
+ robot_type (str): Specification for specific robot arm to be instantiated within this env (e.g: "Panda")
+
+ idn (int or str): Unique ID of this robot. Should be different from others
+
+ controller_config (dict or list of dict --> dict of dict): If set, contains relevant controller parameters
+ for creating custom controllers. Else, uses the default controller for this specific task. Should either
+ be single dict if same controller is to be used for both robot arms or else it should be a list of length 2.
+
+ :NOTE: In the latter case, assumes convention of [right, left]
+
+ initial_qpos (sequence of float): If set, determines the initial joint positions of the robot to be
+ instantiated for the task
+
+ initialization_noise (dict): Dict containing the initialization noise parameters. The expected keys and
+ corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to "None" or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ :Note: Specifying None will automatically create the required dict with "magnitude" set to 0.0
+
+ mount_type (str): type of mount, used to instantiate mount models from mount factory.
+ Default is "default", which is the default mount associated with this robot's corresponding model.
+ None results in no mount, and any other (valid) model overrides the default mount.
+
+ gripper_type (str or list of str --> dict): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default gripper associated
+ within the 'robot' specification. None removes the gripper, and any other (valid) model overrides the
+ default gripper. Should either be single str if same gripper type is to be used for both arms or else
+ it should be a list of length 2
+
+ :NOTE: In the latter case, assumes convention of [right, left]
+
+ control_freq (float): how many control signals to receive
+ in every second. This sets the amount of simulation time
+ that passes between every action input.
+ """
+
+ def __init__(
+ self,
+ robot_type: str,
+ idn=0,
+ controller_config=None,
+ initial_qpos=None,
+ initialization_noise=None,
+ mount_type="default",
+ gripper_type="default",
+ control_freq=20,
+ ):
+
+ self.controller = self._input2dict(None)
+ self.controller_config = self._input2dict(copy.deepcopy(controller_config))
+ self.gripper = self._input2dict(None)
+ self.gripper_type = self._input2dict(gripper_type)
+ self.has_gripper = self._input2dict([gripper_type is not None for _, gripper_type in self.gripper_type.items()])
+
+ self.gripper_joints = self._input2dict(None) # xml joint names for gripper
+ self._ref_gripper_joint_pos_indexes = self._input2dict(None) # xml gripper joint position indexes in mjsim
+ self._ref_gripper_joint_vel_indexes = self._input2dict(None) # xml gripper joint velocity indexes in mjsim
+ self._ref_joint_gripper_actuator_indexes = self._input2dict(
+ None
+ ) # xml gripper (pos) actuator indexes for robot in mjsim
+ self.eef_rot_offset = self._input2dict(None) # rotation offsets from final arm link to gripper (quat)
+ self.eef_site_id = self._input2dict(None) # xml element id for eef in mjsim
+ self.eef_cylinder_id = self._input2dict(None) # xml element id for eef cylinder in mjsim
+ self.torques = None # Current torques being applied
+
+ self.recent_ee_forcetorques = self._input2dict(None) # Current and last forces / torques sensed at eef
+ self.recent_ee_pose = self._input2dict(None) # Current and last eef pose (pos + ori (quat))
+ self.recent_ee_vel = self._input2dict(None) # Current and last eef velocity
+ self.recent_ee_vel_buffer = self._input2dict(None) # RingBuffer holding prior 10 values of velocity values
+ self.recent_ee_acc = self._input2dict(None) # Current and last eef acceleration
+
+ super().__init__(
+ robot_type=robot_type,
+ idn=idn,
+ initial_qpos=initial_qpos,
+ initialization_noise=initialization_noise,
+ mount_type=mount_type,
+ control_freq=control_freq,
+ )
+
+ def _load_controller(self):
+ """
+ Loads controller to be used for dynamic trajectories
+ """
+ # Flag for loading urdf once (only applicable for IK controllers)
+ urdf_loaded = False
+
+ # Load controller configs for both left and right arm
+ for arm in self.arms:
+ # First, load the default controller if none is specified
+ if not self.controller_config[arm]:
+ # Need to update default for a single agent
+ controller_path = os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "controllers/config/{}.json".format(self.robot_model.default_controller_config[arm]),
+ )
+ self.controller_config[arm] = load_controller_config(custom_fpath=controller_path)
+
+ # Assert that the controller config is a dict file:
+ # NOTE: "type" must be one of: {JOINT_POSITION, JOINT_TORQUE, JOINT_VELOCITY,
+ # OSC_POSITION, OSC_POSE, IK_POSE}
+ assert (
+ type(self.controller_config[arm]) == dict
+ ), "Inputted controller config must be a dict! Instead, got type: {}".format(
+ type(self.controller_config[arm])
+ )
+
+ # Add to the controller dict additional relevant params:
+ # the robot name, mujoco sim, eef_name, actuator_range, joint_indexes, timestep (model) freq,
+ # policy (control) freq, and ndim (# joints)
+ self.controller_config[arm]["robot_name"] = self.name
+ self.controller_config[arm]["sim"] = self.sim
+ self.controller_config[arm]["eef_name"] = self.gripper[arm].important_sites["grip_site"]
+ self.controller_config[arm]["eef_rot_offset"] = self.eef_rot_offset[arm]
+ self.controller_config[arm]["ndim"] = self._joint_split_idx
+ self.controller_config[arm]["policy_freq"] = self.control_freq
+ (start, end) = (None, self._joint_split_idx) if arm == "right" else (self._joint_split_idx, None)
+ self.controller_config[arm]["joint_indexes"] = {
+ "joints": self.joint_indexes[start:end],
+ "qpos": self._ref_joint_pos_indexes[start:end],
+ "qvel": self._ref_joint_vel_indexes[start:end],
+ }
+ self.controller_config[arm]["actuator_range"] = (
+ self.torque_limits[0][start:end],
+ self.torque_limits[1][start:end],
+ )
+
+ # Only load urdf the first time this controller gets called
+ self.controller_config[arm]["load_urdf"] = True if not urdf_loaded else False
+ urdf_loaded = True
+
+ # Instantiate the relevant controller
+ self.controller[arm] = controller_factory(self.controller_config[arm]["type"], self.controller_config[arm])
+
+ def load_model(self):
+ """
+ Loads robot and optionally add grippers.
+ """
+ # First, run the superclass method to load the relevant model
+ super().load_model()
+
+ # Verify that the loaded model is of the correct type for this robot
+ if self.robot_model.arm_type != "bimanual":
+ raise TypeError(
+ "Error loading robot model: Incompatible arm type specified for this robot. "
+ "Requested model arm type: {}, robot arm type: {}".format(self.robot_model.arm_type, type(self))
+ )
+
+ # Now, load the gripper if necessary
+ for arm in self.arms:
+ if self.has_gripper[arm]:
+ if self.gripper_type[arm] == "default":
+ # Load the default gripper from the robot file
+ self.gripper[arm] = gripper_factory(
+ self.robot_model.default_gripper[arm], idn="_".join((str(self.idn), arm))
+ )
+ else:
+ # Load user-specified gripper
+ self.gripper[arm] = gripper_factory(self.gripper_type[arm], idn="_".join((str(self.idn), arm)))
+ else:
+ # Load null gripper
+ self.gripper[arm] = gripper_factory(None, idn="_".join((str(self.idn), arm)))
+ # Grab eef rotation offset
+ self.eef_rot_offset[arm] = T.quat_multiply(
+ self.robot_model.hand_rotation_offset[arm], self.gripper[arm].rotation_offset
+ )
+ # Add this gripper to the robot model
+ self.robot_model.add_gripper(self.gripper[arm], self.robot_model.eef_name[arm])
+
+ def reset(self, deterministic=False):
+ """
+ Sets initial pose of arm and grippers. Overrides gripper joint configuration if we're using a
+ deterministic reset (e.g.: hard reset from xml file)
+
+ Args:
+ deterministic (bool): If true, will not randomize initializations within the sim
+ """
+ # First, run the superclass method to reset the position and controller
+ super().reset(deterministic)
+
+ # Setup arm-specific values
+ for arm in self.arms:
+ # Now, reset the grippers if necessary
+ if self.has_gripper[arm]:
+ if not deterministic:
+ self.sim.data.qpos[self._ref_gripper_joint_pos_indexes[arm]] = self.gripper[arm].init_qpos
+
+ self.gripper[arm].current_action = np.zeros(self.gripper[arm].dof)
+
+ # Update base pos / ori references in controller (technically only needs to be called once)
+ self.controller[arm].update_base_pose(self.base_pos, self.base_ori)
+ # Setup buffers for eef values
+ self.recent_ee_forcetorques[arm] = DeltaBuffer(dim=6)
+ self.recent_ee_pose[arm] = DeltaBuffer(dim=7)
+ self.recent_ee_vel[arm] = DeltaBuffer(dim=6)
+ self.recent_ee_vel_buffer[arm] = RingBuffer(dim=6, length=10)
+ self.recent_ee_acc[arm] = DeltaBuffer(dim=6)
+
+ def setup_references(self):
+ """
+ Sets up necessary reference for robots, grippers, and objects.
+
+ Note that this should get called during every reset from the environment
+ """
+ # First, run the superclass method to setup references for joint-related values / indexes
+ super().setup_references()
+
+ # Now, add references to gripper if necessary
+ # indices for grippers in qpos, qvel
+ for arm in self.arms:
+ if self.has_gripper[arm]:
+ self.gripper_joints[arm] = list(self.gripper[arm].joints)
+ self._ref_gripper_joint_pos_indexes[arm] = [
+ self.sim.model.get_joint_qpos_addr(x) for x in self.gripper_joints[arm]
+ ]
+ self._ref_gripper_joint_vel_indexes[arm] = [
+ self.sim.model.get_joint_qvel_addr(x) for x in self.gripper_joints[arm]
+ ]
+ self._ref_joint_gripper_actuator_indexes[arm] = [
+ self.sim.model.actuator_name2id(actuator) for actuator in self.gripper[arm].actuators
+ ]
+
+ # IDs of sites for eef visualization
+ self.eef_site_id[arm] = self.sim.model.site_name2id(self.gripper[arm].important_sites["grip_site"])
+ self.eef_cylinder_id[arm] = self.sim.model.site_name2id(self.gripper[arm].important_sites["grip_cylinder"])
+
+ def control(self, action, policy_step=False):
+ """
+ Actuate the robot with the
+ passed joint velocities and gripper control.
+
+ Args:
+ action (np.array): The control to apply to the robot. The first @self.robot_model.dof dimensions should
+ be the desired normalized joint velocities and if the robot has a gripper, the next @self.gripper.dof
+ dimensions should be actuation controls for the gripper.
+
+ :NOTE: Assumes inputted actions are of form:
+ [right_arm_control, right_gripper_control, left_arm_control, left_gripper_control]
+
+ policy_step (bool): Whether a new policy step (action) is being taken
+
+ Raises:
+ AssertionError: [Invalid action dimension]
+ """
+ # clip actions into valid range
+ assert len(action) == self.action_dim, "environment got invalid action dimension -- expected {}, got {}".format(
+ self.action_dim, len(action)
+ )
+
+ self.torques = np.array([])
+ # Now execute actions for each arm
+ for arm in self.arms:
+ # Make sure to split action space correctly
+ (start, end) = (None, self._action_split_idx) if arm == "right" else (self._action_split_idx, None)
+ sub_action = action[start:end]
+
+ gripper_action = None
+ if self.has_gripper[arm]:
+ # get all indexes past controller dimension indexes
+ gripper_action = sub_action[self.controller[arm].control_dim :]
+ sub_action = sub_action[: self.controller[arm].control_dim]
+
+ # Update the controller goal if this is a new policy step
+ if policy_step:
+ self.controller[arm].set_goal(sub_action)
+
+ # Now run the controller for a step and add it to the torques
+ self.torques = np.concatenate((self.torques, self.controller[arm].run_controller()))
+
+ # Get gripper action, if applicable
+ if self.has_gripper[arm]:
+ self.grip_action(gripper=self.gripper[arm], gripper_action=gripper_action)
+
+ # Clip the torques
+ low, high = self.torque_limits
+ self.torques = np.clip(self.torques, low, high)
+
+ # Apply joint torque control
+ self.sim.data.ctrl[self._ref_joint_actuator_indexes] = self.torques
+
+ # If this is a policy step, also update buffers holding recent values of interest
+ if policy_step:
+ # Update proprioceptive values
+ self.recent_qpos.push(self._joint_positions)
+ self.recent_actions.push(action)
+ self.recent_torques.push(self.torques)
+
+ for arm in self.arms:
+ # Update arm-specific proprioceptive values
+ self.recent_ee_forcetorques[arm].push(np.concatenate((self.ee_force[arm], self.ee_torque[arm])))
+ self.recent_ee_pose[arm].push(
+ np.concatenate((self.controller[arm].ee_pos, T.mat2quat(self.controller[arm].ee_ori_mat)))
+ )
+ self.recent_ee_vel[arm].push(
+ np.concatenate((self.controller[arm].ee_pos_vel, self.controller[arm].ee_ori_vel))
+ )
+
+ # Estimation of eef acceleration (averaged derivative of recent velocities)
+ self.recent_ee_vel_buffer[arm].push(
+ np.concatenate((self.controller[arm].ee_pos_vel, self.controller[arm].ee_ori_vel))
+ )
+ diffs = np.vstack(
+ [
+ self.recent_ee_acc[arm].current,
+ self.control_freq * np.diff(self.recent_ee_vel_buffer[arm].buf, axis=0),
+ ]
+ )
+ ee_acc = np.array([np.convolve(col, np.ones(10) / 10.0, mode="valid")[0] for col in diffs.transpose()])
+ self.recent_ee_acc[arm].push(ee_acc)
+
+ def _visualize_grippers(self, visible):
+ """
+ Visualizes the gripper site(s) if applicable.
+
+ Args:
+ visible (bool): True if visualizing the gripper for this arm.
+ """
+ for arm in self.arms:
+ self.gripper[arm].set_sites_visibility(sim=self.sim, visible=visible)
+
+ def setup_observables(self):
+ """
+ Sets up observables to be used for this robot
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ # Get general robot observables first
+ observables = super().setup_observables()
+
+ # Get prefix from robot model to avoid naming clashes for multiple robots and define observables modality
+ pf = self.robot_model.naming_prefix
+ modality = f"{pf}proprio"
+ sensors = []
+ names = []
+
+ for arm in self.arms:
+ # Add in eef info
+ arm_sensors, arm_sensor_names = self._create_arm_sensors(arm=arm, modality=modality)
+ sensors += arm_sensors
+ names += arm_sensor_names
+
+ # Create observables for this robot
+ for name, s in zip(names, sensors):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ )
+
+ return observables
+
+ def _create_arm_sensors(self, arm, modality):
+ """
+ Helper function to create sensors for a given arm. This is abstracted in a separate function call so that we
+ don't have local function naming collisions during the _setup_observables() call.
+
+ Args:
+ arm (str): Arm to create sensors for
+ modality (str): Modality to assign to all sensors
+
+ Returns:
+ 2-tuple:
+ sensors (list): Array of sensors for the given arm
+ names (list): array of corresponding observable names
+ """
+ pf = self.robot_model.naming_prefix
+
+ # eef features
+ @sensor(modality=modality)
+ def eef_pos(obs_cache):
+ return np.array(self.sim.data.site_xpos[self.eef_site_id[arm]])
+
+ @sensor(modality=modality)
+ def eef_quat(obs_cache):
+ return T.convert_quat(self.sim.data.get_body_xquat(self.robot_model.eef_name[arm]), to="xyzw")
+
+ sensors = [eef_pos, eef_quat]
+ names = [f"{pf}{arm}_eef_pos", f"{pf}{arm}_eef_quat"]
+
+ # add in gripper sensors if this robot has a gripper
+ if self.has_gripper[arm]:
+
+ @sensor(modality=modality)
+ def gripper_qpos(obs_cache):
+ return np.array([self.sim.data.qpos[x] for x in self._ref_gripper_joint_pos_indexes[arm]])
+
+ @sensor(modality=modality)
+ def gripper_qvel(obs_cache):
+ return np.array([self.sim.data.qvel[x] for x in self._ref_gripper_joint_vel_indexes[arm]])
+
+ sensors += [gripper_qpos, gripper_qvel]
+ names += [f"{pf}{arm}_gripper_qpos", f"{pf}{arm}_gripper_qvel"]
+
+ return sensors, names
+
+ def _input2dict(self, inp):
+ """
+ Helper function that converts an input that is either a single value or a list into a dict with keys for
+ each arm: "right", "left"
+
+ Args:
+ inp (str or list or None): Input value to be converted to dict
+
+ :Note: If inp is a list, then assumes format is [right, left]
+
+ Returns:
+ dict: Inputs mapped for each robot arm
+ """
+ # First, convert to list if necessary
+ if type(inp) is not list:
+ inp = [inp for _ in range(2)]
+ # Now, convert list to dict and return
+ return {key: value for key, value in zip(self.arms, inp)}
+
+ @property
+ def arms(self):
+ """
+ Returns name of arms used as naming convention throughout this module
+
+ Returns:
+ 2-tuple: ('right', 'left')
+ """
+ return "right", "left"
+
+ @property
+ def action_limits(self):
+ """
+ Action lower/upper limits per dimension.
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum (low) action values
+ - (np.array) maximum (high) action values
+ """
+ # Action limits based on controller limits
+ low, high = [], []
+ for arm in self.arms:
+ low_g, high_g = (
+ ([-1] * self.gripper[arm].dof, [1] * self.gripper[arm].dof) if self.has_gripper[arm] else ([], [])
+ )
+ low_c, high_c = self.controller[arm].control_limits
+ low, high = np.concatenate([low, low_c, low_g]), np.concatenate([high, high_c, high_g])
+ return low, high
+
+ @property
+ def ee_ft_integral(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the integral over time of the applied ee force-torque for that arm
+ """
+ vals = {}
+ for arm in self.arms:
+ vals[arm] = np.abs((1.0 / self.control_freq) * self.recent_ee_forcetorques[arm].average)
+ return vals
+
+ @property
+ def ee_force(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the force applied at the force sensor at the robot arm's eef
+ """
+ vals = {}
+ for arm in self.arms:
+ vals[arm] = self.get_sensor_measurement(self.gripper[arm].important_sensors["force_ee"])
+ return vals
+
+ @property
+ def ee_torque(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the torque applied at the torque sensor at the robot arm's eef
+ """
+ vals = {}
+ for arm in self.arms:
+ vals[arm] = self.get_sensor_measurement(self.gripper[arm].important_sensors["torque_ee"])
+ return vals
+
+ @property
+ def _hand_pose(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the eef pose in base frame of robot.
+ """
+ vals = {}
+ for arm in self.arms:
+ vals[arm] = self.pose_in_base_from_name(self.robot_model.eef_name[arm])
+ return vals
+
+ @property
+ def _hand_quat(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the eef quaternion in base frame of robot.
+ """
+ vals = {}
+ orns = self._hand_orn
+ for arm in self.arms:
+ vals[arm] = T.mat2quat(orns[arm])
+ return vals
+
+ @property
+ def _hand_total_velocity(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the total eef velocity (linear + angular) in the base frame
+ as a numpy array of shape (6,)
+ """
+ vals = {}
+ for arm in self.arms:
+ # Determine correct start, end points based on arm
+ (start, end) = (None, self._joint_split_idx) if arm == "right" else (self._joint_split_idx, None)
+
+ # Use jacobian to translate joint velocities to end effector velocities.
+ Jp = self.sim.data.get_body_jacp(self.robot_model.eef_name[arm]).reshape((3, -1))
+ Jp_joint = Jp[:, self._ref_joint_vel_indexes[start:end]]
+
+ Jr = self.sim.data.get_body_jacr(self.robot_model.eef_name[arm]).reshape((3, -1))
+ Jr_joint = Jr[:, self._ref_joint_vel_indexes[start:end]]
+
+ eef_lin_vel = Jp_joint.dot(self._joint_velocities)
+ eef_rot_vel = Jr_joint.dot(self._joint_velocities)
+ vals[arm] = np.concatenate([eef_lin_vel, eef_rot_vel])
+ return vals
+
+ @property
+ def _hand_pos(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the position of eef in base frame of robot.
+ """
+ vals = {}
+ poses = self._hand_pose
+ for arm in self.arms:
+ eef_pose_in_base = poses[arm]
+ vals[arm] = eef_pose_in_base[:3, 3]
+ return vals
+
+ @property
+ def _hand_orn(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the orientation of eef in base frame of robot as a rotation matrix.
+ """
+ vals = {}
+ poses = self._hand_pose
+ for arm in self.arms:
+ eef_pose_in_base = poses[arm]
+ vals[arm] = eef_pose_in_base[:3, :3]
+ return vals
+
+ @property
+ def _hand_vel(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the velocity of eef in base frame of robot.
+ """
+ vels = self._hand_total_velocity
+ for arm in self.arms:
+ vels[arm] = vels[arm][:3]
+ return vels
+
+ @property
+ def _hand_ang_vel(self):
+ """
+ Returns:
+ dict: each arm-specific entry specifies the angular velocity of eef in base frame of robot.
+ """
+ vels = self._hand_total_velocity
+ for arm in self.arms:
+ vels[arm] = vels[arm][3:]
+ return vels
+
+ @property
+ def _action_split_idx(self):
+ """
+ Grabs the index that correctly splits the right arm from the left arm actions
+
+ :NOTE: Assumes inputted actions are of form:
+ [right_arm_control, right_gripper_control, left_arm_control, left_gripper_control]
+
+ Returns:
+ int: Index splitting right from left arm actions
+ """
+ return (
+ self.controller["right"].control_dim + self.gripper["right"].dof
+ if self.has_gripper["right"]
+ else self.controller["right"].control_dim
+ )
+
+ @property
+ def _joint_split_idx(self):
+ """
+ Returns:
+ int: the index that correctly splits the right arm from the left arm joints
+ """
+ return int(len(self.robot_joints) / 2)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/robots/manipulator.py b/phantom/submodules/phantom-robosuite/robosuite/robots/manipulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b34e6c56b607dc15fe07823f4ac29345a860d986
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/robots/manipulator.py
@@ -0,0 +1,164 @@
+from robosuite.robots.robot import Robot
+
+
+class Manipulator(Robot):
+ """
+ Initializes a manipulator robot simulation object, as defined by a single corresponding robot arm XML and
+ associated gripper XML
+ """
+
+ def _load_controller(self):
+ raise NotImplementedError
+
+ def control(self, action, policy_step=False):
+ raise NotImplementedError
+
+ def grip_action(self, gripper, gripper_action):
+ """
+ Executes @gripper_action for specified @gripper
+
+ Args:
+ gripper (GripperModel): Gripper to execute action for
+ gripper_action (float): Value between [-1,1] to send to gripper
+ """
+ actuator_idxs = [self.sim.model.actuator_name2id(actuator) for actuator in gripper.actuators]
+ if self.direct_gripper_control:
+ if "Robotiq85" in gripper.name:
+ applied_gripper_action = gripper_action[0]
+ else:
+ applied_gripper_action = [gripper_action[0], -gripper_action[0]]
+ else:
+ gripper_action_actual = gripper.format_action(gripper_action)
+ # rescale normalized gripper action to control ranges
+ ctrl_range = self.sim.model.actuator_ctrlrange[actuator_idxs]
+ bias = 0.5 * (ctrl_range[:, 1] + ctrl_range[:, 0])
+ weight = 0.5 * (ctrl_range[:, 1] - ctrl_range[:, 0])
+ applied_gripper_action = bias + weight * gripper_action_actual
+ self.sim.data.ctrl[actuator_idxs] = applied_gripper_action
+
+ def visualize(self, vis_settings):
+ """
+ Do any necessary visualization for this manipulator
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "robots" and "grippers" keyword as well as any other
+ robot-specific options specified.
+ """
+ super().visualize(vis_settings=vis_settings)
+ self._visualize_grippers(visible=vis_settings["grippers"])
+
+ def _visualize_grippers(self, visible):
+ """
+ Visualizes the gripper site(s) if applicable.
+
+ Args:
+ visible (bool): True if visualizing grippers, else False
+ """
+ raise NotImplementedError
+
+ @property
+ def action_limits(self):
+ raise NotImplementedError
+
+ @property
+ def dof(self):
+ """
+ Returns:
+ int: degrees of freedom of the robot (with grippers).
+ """
+ # Get the dof of the base robot model
+ dof = super().dof
+ for gripper in self.robot_model.grippers.values():
+ dof += gripper.dof
+ return dof
+
+ @property
+ def ee_ft_integral(self):
+ """
+ Returns:
+ float or dict: either single value or arm-specific entries specifying the integral over time of the applied
+ ee force-torque for that arm
+ """
+ raise NotImplementedError
+
+ @property
+ def ee_force(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the force applied at the force sensor
+ at the robot arm's eef
+ """
+ raise NotImplementedError
+
+ @property
+ def ee_torque(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the torque applied at the torque
+ sensor at the robot arm's eef
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_pose(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the eef pose in base frame of
+ robot.
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_quat(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the eef quaternion in base frame
+ of robot.
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_total_velocity(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the total eef velocity
+ (linear + angular) in the base frame as a numpy array of shape (6,)
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_pos(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the position of eef in base frame
+ of robot.
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_orn(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the orientation of eef in base
+ frame of robot as a rotation matrix.
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_vel(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the velocity of eef in base frame
+ of robot.
+ """
+ raise NotImplementedError
+
+ @property
+ def _hand_ang_vel(self):
+ """
+ Returns:
+ np.array or dict: either single value or arm-specific entries specifying the angular velocity of eef in
+ base frame of robot.
+ """
+ raise NotImplementedError
diff --git a/phantom/submodules/phantom-robosuite/robosuite/robots/robot.py b/phantom/submodules/phantom-robosuite/robosuite/robots/robot.py
new file mode 100644
index 0000000000000000000000000000000000000000..e31586aef77f26d92ff26f01130899ee78a7c7e7
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/robots/robot.py
@@ -0,0 +1,387 @@
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.macros as macros
+import robosuite.utils.transform_utils as T
+from robosuite.models.mounts import mount_factory
+from robosuite.models.robots import create_robot
+from robosuite.utils.binding_utils import MjSim
+from robosuite.utils.buffers import DeltaBuffer
+from robosuite.utils.observables import Observable, sensor
+
+
+class Robot(object):
+ """
+ Initializes a robot simulation object, as defined by a single corresponding robot XML
+
+ Args:
+ robot_type (str): Specification for specific robot arm to be instantiated within this env (e.g: "Panda")
+
+ idn (int or str): Unique ID of this robot. Should be different from others
+
+ initial_qpos (sequence of float): If set, determines the initial joint positions of the robot to be
+ instantiated for the task
+
+ initialization_noise (dict): Dict containing the initialization noise parameters. The expected keys and
+ corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to "None" or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ :Note: Specifying None will automatically create the required dict with "magnitude" set to 0.0
+
+ mount_type (str): type of mount, used to instantiate mount models from mount factory.
+ Default is "default", which is the default mount associated with this robot's corresponding model.
+ None results in no mount, and any other (valid) model overrides the default mount.
+
+ control_freq (float): how many control signals to receive
+ in every second. This sets the amount of simulation time
+ that passes between every action input.
+ """
+
+ def __init__(
+ self,
+ robot_type: str,
+ idn=0,
+ initial_qpos=None,
+ initialization_noise=None,
+ mount_type="default",
+ control_freq=20,
+ ):
+ # Set relevant attributes
+ self.sim = None # MjSim this robot is tied to
+ self.name = robot_type # Specific robot to instantiate
+ self.idn = idn # Unique ID of this robot
+ self.robot_model = None # object holding robot model-specific info
+ self.control_freq = control_freq # controller Hz
+ self.mount_type = mount_type # Type of mount to use
+
+ # Scaling of Gaussian initial noise applied to robot joints
+ self.initialization_noise = initialization_noise
+ if self.initialization_noise is None:
+ self.initialization_noise = {"magnitude": 0.0, "type": "gaussian"} # no noise conditions
+ elif self.initialization_noise == "default":
+ self.initialization_noise = {"magnitude": 0.02, "type": "gaussian"}
+ self.initialization_noise["magnitude"] = (
+ self.initialization_noise["magnitude"] if self.initialization_noise["magnitude"] else 0.0
+ )
+
+ self.init_qpos = initial_qpos # n-dim list / array of robot joints
+
+ self.robot_joints = None # xml joint names for robot
+ self.base_pos = None # Base position in world coordinates (x,y,z)
+ self.base_ori = None # Base rotation in world coordinates (x,y,z,w quat)
+ self._ref_joint_indexes = None # xml joint indexes for robot in mjsim
+ self._ref_joint_pos_indexes = None # xml joint position indexes in mjsim
+ self._ref_joint_vel_indexes = None # xml joint velocity indexes in mjsim
+ self._ref_joint_actuator_indexes = None # xml joint (torq) actuator indexes for robot in mjsim
+
+ self.recent_qpos = None # Current and last robot arm qpos
+ self.recent_actions = None # Current and last action applied
+ self.recent_torques = None # Current and last torques applied
+
+ def _load_controller(self):
+ """
+ Loads controller to be used for dynamic trajectories.
+ """
+ raise NotImplementedError
+
+ def load_model(self):
+ """
+ Loads robot and optionally add grippers.
+ """
+ self.robot_model = create_robot(self.name, idn=self.idn)
+
+ # Add mount if specified
+ if self.mount_type == "default":
+ self.robot_model.add_mount(mount=mount_factory(self.robot_model.default_mount, idn=self.idn))
+ else:
+ self.robot_model.add_mount(mount=mount_factory(self.mount_type, idn=self.idn))
+
+ # Use default from robot model for initial joint positions if not specified
+ if self.init_qpos is None:
+ self.init_qpos = self.robot_model.init_qpos
+
+ def reset_sim(self, sim: MjSim):
+ """
+ Replaces current sim with a new sim
+
+ Args:
+ sim (MjSim): New simulation being instantiated to replace the old one
+ """
+ self.sim = sim
+
+ def reset(self, deterministic=False):
+ """
+ Sets initial pose of arm and grippers. Overrides robot joint configuration if we're using a
+ deterministic reset (e.g.: hard reset from xml file)
+
+ Args:
+ deterministic (bool): If true, will not randomize initializations within the sim
+
+ Raises:
+ ValueError: [Invalid noise type]
+ """
+ init_qpos = np.array(self.init_qpos)
+ if not deterministic:
+ # Determine noise
+ if self.initialization_noise["type"] == "gaussian":
+ noise = np.random.randn(len(self.init_qpos)) * self.initialization_noise["magnitude"]
+ elif self.initialization_noise["type"] == "uniform":
+ noise = np.random.uniform(-1.0, 1.0, len(self.init_qpos)) * self.initialization_noise["magnitude"]
+ else:
+ raise ValueError("Error: Invalid noise type specified. Options are 'gaussian' or 'uniform'.")
+ init_qpos += noise
+
+ # Set initial position in sim
+ self.sim.data.qpos[self._ref_joint_pos_indexes] = init_qpos
+
+ # Load controllers
+ self._load_controller()
+
+ # Update base pos / ori references
+ self.base_pos = self.sim.data.get_body_xpos(self.robot_model.root_body)
+ self.base_ori = T.mat2quat(self.sim.data.get_body_xmat(self.robot_model.root_body).reshape((3, 3)))
+
+ # Setup buffers to hold recent values
+ self.recent_qpos = DeltaBuffer(dim=len(self.joint_indexes))
+ self.recent_actions = DeltaBuffer(dim=self.action_dim)
+ self.recent_torques = DeltaBuffer(dim=len(self.joint_indexes))
+
+ def setup_references(self):
+ """
+ Sets up necessary reference for robots, grippers, and objects.
+ """
+ # indices for joints in qpos, qvel
+ self.robot_joints = self.robot_model.joints
+ self._ref_joint_pos_indexes = [self.sim.model.get_joint_qpos_addr(x) for x in self.robot_joints]
+ self._ref_joint_vel_indexes = [self.sim.model.get_joint_qvel_addr(x) for x in self.robot_joints]
+
+ # indices for joint indexes
+ self._ref_joint_indexes = [self.sim.model.joint_name2id(joint) for joint in self.robot_model.joints]
+
+ # indices for joint pos actuation, joint vel actuation, gripper actuation
+ self._ref_joint_actuator_indexes = [
+ self.sim.model.actuator_name2id(actuator) for actuator in self.robot_model.actuators
+ ]
+
+ def setup_observables(self):
+ """
+ Sets up observables to be used for this robot
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ # Get prefix from robot model to avoid naming clashes for multiple robots and define observables modality
+ pf = self.robot_model.naming_prefix
+ pre_compute = f"{pf}joint_pos"
+ modality = f"{pf}proprio"
+
+ # proprioceptive features
+ @sensor(modality=modality)
+ def joint_pos(obs_cache):
+ return np.array([self.sim.data.qpos[x] for x in self._ref_joint_pos_indexes])
+
+ @sensor(modality=modality)
+ def joint_pos_cos(obs_cache):
+ return np.cos(obs_cache[pre_compute]) if pre_compute in obs_cache else np.zeros(self.robot_model.dof)
+
+ @sensor(modality=modality)
+ def joint_pos_sin(obs_cache):
+ return np.sin(obs_cache[pre_compute]) if pre_compute in obs_cache else np.zeros(self.robot_model.dof)
+
+ @sensor(modality=modality)
+ def joint_vel(obs_cache):
+ return np.array([self.sim.data.qvel[x] for x in self._ref_joint_vel_indexes])
+
+ sensors = [joint_pos, joint_pos_cos, joint_pos_sin, joint_vel]
+ names = ["joint_pos", "joint_pos_cos", "joint_pos_sin", "joint_vel"]
+ # We don't want to include the direct joint pos sensor outputs
+ actives = [False, True, True, True]
+
+ # Create observables for this robot
+ observables = OrderedDict()
+ for name, s, active in zip(names, sensors, actives):
+ obs_name = pf + name
+ observables[obs_name] = Observable(
+ name=obs_name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ active=active,
+ )
+
+ return observables
+
+ def control(self, action, policy_step=False):
+ """
+ Actuate the robot with the
+ passed joint velocities and gripper control.
+
+ Args:
+ action (np.array): The control to apply to the robot. The first @self.robot_model.dof dimensions should
+ be the desired normalized joint velocities and if the robot has a gripper, the next @self.gripper.dof
+ dimensions should be actuation controls for the gripper.
+ policy_step (bool): Whether a new policy step (action) is being taken
+ """
+ raise NotImplementedError
+
+ def check_q_limits(self):
+ """
+ Check if this robot is either very close or at the joint limits
+
+ Returns:
+ bool: True if this arm is near its joint limits
+ """
+ tolerance = 0.1
+ for (qidx, (q, q_limits)) in enumerate(
+ zip(self.sim.data.qpos[self._ref_joint_pos_indexes], self.sim.model.jnt_range[self._ref_joint_indexes])
+ ):
+ if q_limits[0] != q_limits[1] and not (q_limits[0] + tolerance < q < q_limits[1] - tolerance):
+ print("Joint limit reached in joint " + str(qidx))
+ return True
+ return False
+
+ def visualize(self, vis_settings):
+ """
+ Do any necessary visualization for this robot
+
+ Args:
+ vis_settings (dict): Visualization keywords mapped to T/F, determining whether that specific
+ component should be visualized. Should have "robots" keyword as well as any other robot-specific
+ options specified.
+ """
+ self.robot_model.set_sites_visibility(sim=self.sim, visible=vis_settings["robots"])
+
+ @property
+ def action_limits(self):
+ """
+ Action lower/upper limits per dimension.
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum (low) action values
+ - (np.array) maximum (high) action values
+ """
+ raise NotImplementedError
+
+ @property
+ def torque_limits(self):
+ """
+ Torque lower/upper limits per dimension.
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum (low) torque values
+ - (np.array) maximum (high) torque values
+ """
+ # Torque limit values pulled from relevant robot.xml file
+ low = self.sim.model.actuator_ctrlrange[self._ref_joint_actuator_indexes, 0]
+ high = self.sim.model.actuator_ctrlrange[self._ref_joint_actuator_indexes, 1]
+
+ return low, high
+
+ @property
+ def action_dim(self):
+ """
+ Action space dimension for this robot
+ """
+ return self.action_limits[0].shape[0]
+
+ @property
+ def dof(self):
+ """
+ Returns:
+ int: the active DoF of the robot (Number of robot joints + active gripper DoF).
+ """
+ dof = self.robot_model.dof
+ return dof
+
+ def pose_in_base_from_name(self, name):
+ """
+ A helper function that takes in a named data field and returns the pose
+ of that object in the base frame.
+
+ Args:
+ name (str): Name of body in sim to grab pose
+
+ Returns:
+ np.array: (4,4) array corresponding to the pose of @name in the base frame
+ """
+
+ pos_in_world = self.sim.data.get_body_xpos(name)
+ rot_in_world = self.sim.data.get_body_xmat(name).reshape((3, 3))
+ pose_in_world = T.make_pose(pos_in_world, rot_in_world)
+
+ base_pos_in_world = self.sim.data.get_body_xpos(self.robot_model.root_body)
+ base_rot_in_world = self.sim.data.get_body_xmat(self.robot_model.root_body).reshape((3, 3))
+ base_pose_in_world = T.make_pose(base_pos_in_world, base_rot_in_world)
+ world_pose_in_base = T.pose_inv(base_pose_in_world)
+
+ pose_in_base = T.pose_in_A_to_pose_in_B(pose_in_world, world_pose_in_base)
+ return pose_in_base
+
+ def set_robot_joint_positions(self, jpos):
+ """
+ Helper method to force robot joint positions to the passed values.
+
+ Args:
+ jpos (np.array): Joint positions to manually set the robot to
+ """
+ self.sim.data.qpos[self._ref_joint_pos_indexes] = jpos
+ self.sim.forward()
+
+ @property
+ def js_energy(self):
+ """
+ Returns:
+ np.array: the energy consumed by each joint between previous and current steps
+ """
+ # We assume in the motors torque is proportional to current (and voltage is constant)
+ # In that case the amount of power scales proportional to the torque and the energy is the
+ # time integral of that
+ # Note that we use mean torque
+ return np.abs((1.0 / self.control_freq) * self.recent_torques.average)
+
+ @property
+ def _joint_positions(self):
+ """
+ Returns:
+ np.array: joint positions (in angles / radians)
+ """
+ return self.sim.data.qpos[self._ref_joint_pos_indexes]
+
+ @property
+ def _joint_velocities(self):
+ """
+ Returns:
+ np.array: joint velocities (angular velocity)
+ """
+ return self.sim.data.qvel[self._ref_joint_vel_indexes]
+
+ @property
+ def joint_indexes(self):
+ """
+ Returns:
+ list: mujoco internal indexes for the robot joints
+ """
+ return self._ref_joint_indexes
+
+ def get_sensor_measurement(self, sensor_name):
+ """
+ Grabs relevant sensor data from the sim object
+
+ Args:
+ sensor_name (str): name of the sensor
+
+ Returns:
+ np.array: sensor values
+ """
+ sensor_idx = np.sum(self.sim.model.sensor_dim[: self.sim.model.sensor_name2id(sensor_name)])
+ sensor_dim = self.sim.model.sensor_dim[self.sim.model.sensor_name2id(sensor_name)]
+ return np.array(self.sim.data.sensordata[sensor_idx : sensor_idx + sensor_dim])
diff --git a/phantom/submodules/phantom-robosuite/robosuite/robots/single_arm.py b/phantom/submodules/phantom-robosuite/robosuite/robots/single_arm.py
new file mode 100644
index 0000000000000000000000000000000000000000..934f91728a801faf113f4a5f1eacdfd4868c66dd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/robots/single_arm.py
@@ -0,0 +1,463 @@
+import copy
+import os
+from collections import OrderedDict
+
+import numpy as np
+
+import robosuite.utils.transform_utils as T
+from robosuite.controllers import controller_factory, load_controller_config
+from robosuite.models.grippers import gripper_factory
+from robosuite.robots.manipulator import Manipulator
+from robosuite.utils.buffers import DeltaBuffer, RingBuffer
+from robosuite.utils.observables import Observable, sensor
+
+
+class SingleArm(Manipulator):
+ """
+ Initializes a single-armed robot simulation object.
+
+ Args:
+ robot_type (str): Specification for specific robot arm to be instantiated within this env (e.g: "Panda")
+
+ idn (int or str): Unique ID of this robot. Should be different from others
+
+ controller_config (dict): If set, contains relevant controller parameters for creating a custom controller.
+ Else, uses the default controller for this specific task
+
+ initial_qpos (sequence of float): If set, determines the initial joint positions of the robot to be
+ instantiated for the task
+
+ initialization_noise (dict): Dict containing the initialization noise parameters. The expected keys and
+ corresponding value types are specified below:
+
+ :`'magnitude'`: The scale factor of uni-variate random noise applied to each of a robot's given initial
+ joint positions. Setting this value to "None" or 0.0 results in no noise being applied.
+ If "gaussian" type of noise is applied then this magnitude scales the standard deviation applied,
+ If "uniform" type of noise is applied then this magnitude sets the bounds of the sampling range
+ :`'type'`: Type of noise to apply. Can either specify "gaussian" or "uniform"
+
+ :Note: Specifying None will automatically create the required dict with "magnitude" set to 0.0
+
+ mount_type (str): type of mount, used to instantiate mount models from mount factory.
+ Default is "default", which is the default mount associated with this robot's corresponding model.
+ None results in no mount, and any other (valid) model overrides the default mount.
+
+ gripper_type (str): type of gripper, used to instantiate
+ gripper models from gripper factory. Default is "default", which is the default gripper associated
+ within the 'robot' specification. None removes the gripper, and any other (valid) model overrides the
+ default gripper
+
+ control_freq (float): how many control signals to receive
+ in every second. This sets the amount of simulation time
+ that passes between every action input.
+ """
+
+ def __init__(
+ self,
+ robot_type: str,
+ idn=0,
+ controller_config=None,
+ initial_qpos=None,
+ initialization_noise=None,
+ mount_type="default",
+ gripper_type="default",
+ control_freq=20,
+ direct_gripper_control=False,
+ ):
+
+ self.controller = None
+ self.controller_config = copy.deepcopy(controller_config)
+ self.gripper_type = gripper_type
+ self.has_gripper = self.gripper_type is not None
+
+ self.gripper = None # Gripper class
+ self.gripper_joints = None # xml joint names for gripper
+ self._ref_gripper_joint_pos_indexes = None # xml gripper joint position indexes in mjsim
+ self._ref_gripper_joint_vel_indexes = None # xml gripper joint velocity indexes in mjsim
+ self._ref_joint_gripper_actuator_indexes = None # xml gripper (pos) actuator indexes for robot in mjsim
+ self.eef_rot_offset = None # rotation offsets from final arm link to gripper (quat)
+ self.eef_site_id = None # xml element id for eef in mjsim
+ self.eef_cylinder_id = None # xml element id for eef cylinder in mjsim
+ self.torques = None # Current torques being applied
+
+ self.recent_ee_forcetorques = None # Current and last forces / torques sensed at eef
+ self.recent_ee_pose = None # Current and last eef pose (pos + ori (quat))
+ self.recent_ee_vel = None # Current and last eef velocity
+ self.recent_ee_vel_buffer = None # RingBuffer holding prior 10 values of velocity values
+ self.recent_ee_acc = None # Current and last eef acceleration
+
+ self.direct_gripper_control = direct_gripper_control
+
+ super().__init__(
+ robot_type=robot_type,
+ idn=idn,
+ initial_qpos=initial_qpos,
+ initialization_noise=initialization_noise,
+ mount_type=mount_type,
+ control_freq=control_freq,
+ )
+
+ def _load_controller(self):
+ """
+ Loads controller to be used for dynamic trajectories
+ """
+ # First, load the default controller if none is specified
+ if not self.controller_config:
+ # Need to update default for a single agent
+ controller_path = os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "controllers/config/{}.json".format(self.robot_model.default_controller_config),
+ )
+ self.controller_config = load_controller_config(custom_fpath=controller_path)
+
+ # Assert that the controller config is a dict file:
+ # NOTE: "type" must be one of: {JOINT_POSITION, JOINT_TORQUE, JOINT_VELOCITY,
+ # OSC_POSITION, OSC_POSE, IK_POSE}
+ assert (
+ type(self.controller_config) == dict
+ ), "Inputted controller config must be a dict! Instead, got type: {}".format(type(self.controller_config))
+
+ # Add to the controller dict additional relevant params:
+ # the robot name, mujoco sim, eef_name, joint_indexes, timestep (model) freq,
+ # policy (control) freq, and ndim (# joints)
+ self.controller_config["robot_name"] = self.name
+ self.controller_config["sim"] = self.sim
+ self.controller_config["eef_name"] = self.gripper.important_sites["grip_site"]
+ self.controller_config["eef_rot_offset"] = self.eef_rot_offset
+ self.controller_config["joint_indexes"] = {
+ "joints": self.joint_indexes,
+ "qpos": self._ref_joint_pos_indexes,
+ "qvel": self._ref_joint_vel_indexes,
+ }
+ self.controller_config["actuator_range"] = self.torque_limits
+ self.controller_config["policy_freq"] = self.control_freq
+ self.controller_config["ndim"] = len(self.robot_joints)
+
+ # Instantiate the relevant controller
+ self.controller = controller_factory(self.controller_config["type"], self.controller_config)
+
+ def load_model(self):
+ """
+ Loads robot and optionally add grippers.
+ """
+ # First, run the superclass method to load the relevant model
+ super().load_model()
+
+ # Verify that the loaded model is of the correct type for this robot
+ if self.robot_model.arm_type != "single":
+ raise TypeError(
+ "Error loading robot model: Incompatible arm type specified for this robot. "
+ "Requested model arm type: {}, robot arm type: {}".format(self.robot_model.arm_type, type(self))
+ )
+
+ # Now, load the gripper if necessary
+ if self.has_gripper:
+ if self.gripper_type == "default":
+ # Load the default gripper from the robot file
+ self.gripper = gripper_factory(self.robot_model.default_gripper, idn=self.idn)
+ else:
+ # Load user-specified gripper
+ self.gripper = gripper_factory(self.gripper_type, idn=self.idn)
+ else:
+ # Load null gripper
+ self.gripper = gripper_factory(None, idn=self.idn)
+ # Grab eef rotation offset
+ self.eef_rot_offset = T.quat_multiply(self.robot_model.hand_rotation_offset, self.gripper.rotation_offset)
+ # Add gripper to this robot model
+ self.robot_model.add_gripper(self.gripper)
+
+ def reset(self, deterministic=False):
+ """
+ Sets initial pose of arm and grippers. Overrides gripper joint configuration if we're using a
+ deterministic reset (e.g.: hard reset from xml file)
+
+ Args:
+ deterministic (bool): If true, will not randomize initializations within the sim
+ """
+ # First, run the superclass method to reset the position and controller
+ super().reset(deterministic)
+
+ # Now, reset the gripper if necessary
+ if self.has_gripper:
+ if not deterministic:
+ self.sim.data.qpos[self._ref_gripper_joint_pos_indexes] = self.gripper.init_qpos
+
+ self.gripper.current_action = np.zeros(self.gripper.dof)
+
+ # Update base pos / ori references in controller
+ self.controller.update_base_pose(self.base_pos, self.base_ori)
+
+ # # Setup buffers to hold recent values
+ self.recent_ee_forcetorques = DeltaBuffer(dim=6)
+ self.recent_ee_pose = DeltaBuffer(dim=7)
+ self.recent_ee_vel = DeltaBuffer(dim=6)
+ self.recent_ee_vel_buffer = RingBuffer(dim=6, length=10)
+ self.recent_ee_acc = DeltaBuffer(dim=6)
+
+ def setup_references(self):
+ """
+ Sets up necessary reference for robots, grippers, and objects.
+
+ Note that this should get called during every reset from the environment
+ """
+ # First, run the superclass method to setup references for joint-related values / indexes
+ super().setup_references()
+
+ # Now, add references to gripper if necessary
+ # indices for grippers in qpos, qvel
+ if self.has_gripper:
+ self.gripper_joints = list(self.gripper.joints)
+ self._ref_gripper_joint_pos_indexes = [self.sim.model.get_joint_qpos_addr(x) for x in self.gripper_joints]
+ self._ref_gripper_joint_vel_indexes = [self.sim.model.get_joint_qvel_addr(x) for x in self.gripper_joints]
+ self._ref_joint_gripper_actuator_indexes = [
+ self.sim.model.actuator_name2id(actuator) for actuator in self.gripper.actuators
+ ]
+
+ # IDs of sites for eef visualization
+ self.eef_site_id = self.sim.model.site_name2id(self.gripper.important_sites["grip_site"])
+ self.eef_cylinder_id = self.sim.model.site_name2id(self.gripper.important_sites["grip_cylinder"])
+
+ def control(self, action, policy_step=False):
+ """
+ Actuate the robot with the
+ passed joint velocities and gripper control.
+
+ Args:
+ action (np.array): The control to apply to the robot. The first @self.robot_model.dof dimensions should be
+ the desired normalized joint velocities and if the robot has a gripper, the next @self.gripper.dof
+ dimensions should be actuation controls for the gripper.
+ policy_step (bool): Whether a new policy step (action) is being taken
+
+ Raises:
+ AssertionError: [Invalid action dimension]
+ """
+
+ # clip actions into valid range
+ assert len(action) == self.action_dim, "environment got invalid action dimension -- expected {}, got {}".format(
+ self.action_dim, len(action)
+ )
+
+ gripper_action = None
+ if self.has_gripper:
+ gripper_action = action[self.controller.control_dim :] # all indexes past controller dimension indexes
+ arm_action = action[: self.controller.control_dim]
+ else:
+ arm_action = action
+
+ # Update the controller goal if this is a new policy step
+ if policy_step:
+ self.controller.set_goal(arm_action)
+
+ # Now run the controller for a step
+ torques = self.controller.run_controller()
+
+ # Clip the torques
+ low, high = self.torque_limits
+ self.torques = np.clip(torques, low, high)
+
+ # Get gripper action, if applicable
+ if self.has_gripper:
+ self.grip_action(gripper=self.gripper, gripper_action=gripper_action)
+
+ # Apply joint torque control
+ self.sim.data.ctrl[self._ref_joint_actuator_indexes] = self.torques
+
+ # If this is a policy step, also update buffers holding recent values of interest
+ if policy_step:
+ # Update proprioceptive values
+ self.recent_qpos.push(self._joint_positions)
+ self.recent_actions.push(action)
+ self.recent_torques.push(self.torques)
+ self.recent_ee_forcetorques.push(np.concatenate((self.ee_force, self.ee_torque)))
+ self.recent_ee_pose.push(np.concatenate((self.controller.ee_pos, T.mat2quat(self.controller.ee_ori_mat))))
+ self.recent_ee_vel.push(np.concatenate((self.controller.ee_pos_vel, self.controller.ee_ori_vel)))
+
+ # Estimation of eef acceleration (averaged derivative of recent velocities)
+ self.recent_ee_vel_buffer.push(np.concatenate((self.controller.ee_pos_vel, self.controller.ee_ori_vel)))
+ diffs = np.vstack(
+ [self.recent_ee_acc.current, self.control_freq * np.diff(self.recent_ee_vel_buffer.buf, axis=0)]
+ )
+ ee_acc = np.array([np.convolve(col, np.ones(10) / 10.0, mode="valid")[0] for col in diffs.transpose()])
+ self.recent_ee_acc.push(ee_acc)
+
+ def _visualize_grippers(self, visible):
+ """
+ Visualizes the gripper site(s) if applicable.
+
+ Args:
+ visible (bool): True if visualizing the gripper for this arm.
+ """
+ self.gripper.set_sites_visibility(sim=self.sim, visible=visible)
+
+ def setup_observables(self):
+ """
+ Sets up observables to be used for this robot
+
+ Returns:
+ OrderedDict: Dictionary mapping observable names to its corresponding Observable object
+ """
+ # Get general robot observables first
+ observables = super().setup_observables()
+
+ # Get prefix from robot model to avoid naming clashes for multiple robots and define observables modality
+ pf = self.robot_model.naming_prefix
+ modality = f"{pf}proprio"
+
+ # eef features
+ @sensor(modality=modality)
+ def eef_pos(obs_cache):
+ return np.array(self.sim.data.site_xpos[self.eef_site_id])
+
+ @sensor(modality=modality)
+ def eef_quat(obs_cache):
+ return T.convert_quat(self.sim.data.get_body_xquat(self.robot_model.eef_name), to="xyzw")
+
+ @sensor(modality=modality)
+ def eef_vel_lin(obs_cache):
+ return np.array(self.sim.data.get_body_xvelp(self.robot_model.eef_name))
+
+ @sensor(modality=modality)
+ def eef_vel_ang(obs_cache):
+ return np.array(self.sim.data.get_body_xvelr(self.robot_model.eef_name))
+
+ sensors = [eef_pos, eef_quat, eef_vel_lin, eef_vel_ang]
+ names = [f"{pf}eef_pos", f"{pf}eef_quat", f"{pf}eef_vel_lin", f"{pf}eef_vel_ang"]
+ # Exclude eef vel by default
+ actives = [True, True, False, False]
+
+ # add in gripper sensors if this robot has a gripper
+ if self.has_gripper:
+
+ @sensor(modality=modality)
+ def gripper_qpos(obs_cache):
+ return np.array([self.sim.data.qpos[x] for x in self._ref_gripper_joint_pos_indexes])
+
+ @sensor(modality=modality)
+ def gripper_qvel(obs_cache):
+ return np.array([self.sim.data.qvel[x] for x in self._ref_gripper_joint_vel_indexes])
+
+ sensors += [gripper_qpos, gripper_qvel]
+ names += [f"{pf}gripper_qpos", f"{pf}gripper_qvel"]
+ actives += [True, True]
+
+ # Create observables for this robot
+ for name, s, active in zip(names, sensors, actives):
+ observables[name] = Observable(
+ name=name,
+ sensor=s,
+ sampling_rate=self.control_freq,
+ active=active,
+ )
+
+ return observables
+
+ @property
+ def action_limits(self):
+ """
+ Action lower/upper limits per dimension.
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum (low) action values
+ - (np.array) maximum (high) action values
+ """
+ # Action limits based on controller limits
+ low, high = ([-1] * self.gripper.dof, [1] * self.gripper.dof) if self.has_gripper else ([], [])
+ low_c, high_c = self.controller.control_limits
+ low = np.concatenate([low_c, low])
+ high = np.concatenate([high_c, high])
+
+ return low, high
+
+ @property
+ def ee_ft_integral(self):
+ """
+ Returns:
+ np.array: the integral over time of the applied ee force-torque
+ """
+ return np.abs((1.0 / self.control_freq) * self.recent_ee_forcetorques.average)
+
+ @property
+ def ee_force(self):
+ """
+ Returns:
+ np.array: force applied at the force sensor at the robot arm's eef
+ """
+ return self.get_sensor_measurement(self.gripper.important_sensors["force_ee"])
+
+ @property
+ def ee_torque(self):
+ """
+ Returns torque applied at the torque sensor at the robot arm's eef
+ """
+ return self.get_sensor_measurement(self.gripper.important_sensors["torque_ee"])
+
+ @property
+ def _hand_pose(self):
+ """
+ Returns:
+ np.array: (4,4) array corresponding to the eef pose in base frame of robot.
+ """
+ return self.pose_in_base_from_name(self.robot_model.eef_name)
+
+ @property
+ def _hand_quat(self):
+ """
+ Returns:
+ np.array: (x,y,z,w) eef quaternion in base frame of robot.
+ """
+ return T.mat2quat(self._hand_orn)
+
+ @property
+ def _hand_total_velocity(self):
+ """
+ Returns:
+ np.array: 6-array representing the total eef velocity (linear + angular) in the base frame
+ """
+
+ # Use jacobian to translate joint velocities to end effector velocities.
+ Jp = self.sim.data.get_body_jacp(self.robot_model.eef_name).reshape((3, -1))
+ Jp_joint = Jp[:, self._ref_joint_vel_indexes]
+
+ Jr = self.sim.data.get_body_jacr(self.robot_model.eef_name).reshape((3, -1))
+ Jr_joint = Jr[:, self._ref_joint_vel_indexes]
+
+ eef_lin_vel = Jp_joint.dot(self._joint_velocities)
+ eef_rot_vel = Jr_joint.dot(self._joint_velocities)
+ return np.concatenate([eef_lin_vel, eef_rot_vel])
+
+ @property
+ def _hand_pos(self):
+ """
+ Returns:
+ np.array: 3-array representing the position of eef in base frame of robot.
+ """
+ eef_pose_in_base = self._hand_pose
+ return eef_pose_in_base[:3, 3]
+
+ @property
+ def _hand_orn(self):
+ """
+ Returns:
+ np.array: (3,3) array representing the orientation of eef in base frame of robot as a rotation matrix.
+ """
+ eef_pose_in_base = self._hand_pose
+ return eef_pose_in_base[:3, :3]
+
+ @property
+ def _hand_vel(self):
+ """
+ Returns:
+ np.array: (x,y,z) velocity of eef in base frame of robot.
+ """
+ return self._hand_total_velocity[:3]
+
+ @property
+ def _hand_ang_vel(self):
+ """
+ Returns:
+ np.array: (ax,ay,az) angular velocity of eef in base frame of robot.
+ """
+ return self._hand_total_velocity[3:]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/browse_mjcf_model.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/browse_mjcf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..02f87f6edac3181a4278483be99ea279ad510d7b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/browse_mjcf_model.py
@@ -0,0 +1,35 @@
+"""Visualize MJCF models.
+
+Loads MJCF XML models from file and renders it on screen.
+
+Example:
+ $ python browse_mjcf_model.py --filepath ../models/assets/arenas/table_arena.xml
+"""
+
+import argparse
+import os
+
+import mujoco
+
+import robosuite as suite
+from robosuite.utils import OpenCVRenderer
+from robosuite.utils.binding_utils import MjRenderContext, MjSim
+
+if __name__ == "__main__":
+
+ arena_file = os.path.join(suite.models.assets_root, "arenas/pegs_arena.xml")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--filepath", type=str, default=arena_file)
+ args = parser.parse_args()
+
+ model = mujoco.MjModel.from_xml_path(args.filepath)
+ sim = MjSim(model)
+ render_context = MjRenderContext(sim)
+ sim.add_render_context(render_context)
+ viewer = OpenCVRenderer(sim)
+
+ print("Press ESC to exit...")
+ while True:
+ sim.step()
+ viewer.render()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/collect_human_demonstrations.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/collect_human_demonstrations.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce59432f856bac14dffeb6f871be733c9a1d8e4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/collect_human_demonstrations.py
@@ -0,0 +1,253 @@
+"""
+A script to collect a batch of human demonstrations.
+
+The demonstrations can be played back using the `playback_demonstrations_from_hdf5.py` script.
+"""
+
+import argparse
+import datetime
+import json
+import os
+import shutil
+import time
+from glob import glob
+
+import h5py
+import numpy as np
+
+import robosuite as suite
+import robosuite.macros as macros
+from robosuite import load_controller_config
+from robosuite.utils.input_utils import input2action
+from robosuite.wrappers import DataCollectionWrapper, VisualizationWrapper
+
+
+def collect_human_trajectory(env, device, arm, env_configuration):
+ """
+ Use the device (keyboard or SpaceNav 3D mouse) to collect a demonstration.
+ The rollout trajectory is saved to files in npz format.
+ Modify the DataCollectionWrapper wrapper to add new fields or change data formats.
+
+ Args:
+ env (MujocoEnv): environment to control
+ device (Device): to receive controls from the device
+ arms (str): which arm to control (eg bimanual) 'right' or 'left'
+ env_configuration (str): specified environment configuration
+ """
+
+ env.reset()
+
+ # ID = 2 always corresponds to agentview
+ env.render()
+
+ is_first = True
+
+ task_completion_hold_count = -1 # counter to collect 10 timesteps after reaching goal
+ device.start_control()
+
+ # Loop until we get a reset from the input or the task completes
+ while True:
+ # Set active robot
+ active_robot = env.robots[0] if env_configuration == "bimanual" else env.robots[arm == "left"]
+
+ # Get the newest action
+ action, grasp = input2action(
+ device=device, robot=active_robot, active_arm=arm, env_configuration=env_configuration
+ )
+
+ # If action is none, then this a reset so we should break
+ if action is None:
+ break
+
+ # Run environment step
+ env.step(action)
+ env.render()
+
+ # Also break if we complete the task
+ if task_completion_hold_count == 0:
+ break
+
+ # state machine to check for having a success for 10 consecutive timesteps
+ if env._check_success():
+ if task_completion_hold_count > 0:
+ task_completion_hold_count -= 1 # latched state, decrement count
+ else:
+ task_completion_hold_count = 10 # reset count on first success timestep
+ else:
+ task_completion_hold_count = -1 # null the counter if there's no success
+
+ # cleanup for end of data collection episodes
+ env.close()
+
+
+def gather_demonstrations_as_hdf5(directory, out_dir, env_info):
+ """
+ Gathers the demonstrations saved in @directory into a
+ single hdf5 file.
+
+ The strucure of the hdf5 file is as follows.
+
+ data (group)
+ date (attribute) - date of collection
+ time (attribute) - time of collection
+ repository_version (attribute) - repository version used during collection
+ env (attribute) - environment name on which demos were collected
+
+ demo1 (group) - every demonstration has a group
+ model_file (attribute) - model xml string for demonstration
+ states (dataset) - flattened mujoco states
+ actions (dataset) - actions applied during demonstration
+
+ demo2 (group)
+ ...
+
+ Args:
+ directory (str): Path to the directory containing raw demonstrations.
+ out_dir (str): Path to where to store the hdf5 file.
+ env_info (str): JSON-encoded string containing environment information,
+ including controller and robot info
+ """
+
+ hdf5_path = os.path.join(out_dir, "demo.hdf5")
+ f = h5py.File(hdf5_path, "w")
+
+ # store some metadata in the attributes of one group
+ grp = f.create_group("data")
+
+ num_eps = 0
+ env_name = None # will get populated at some point
+
+ for ep_directory in os.listdir(directory):
+
+ state_paths = os.path.join(directory, ep_directory, "state_*.npz")
+ states = []
+ actions = []
+ success = False
+
+ for state_file in sorted(glob(state_paths)):
+ dic = np.load(state_file, allow_pickle=True)
+ env_name = str(dic["env"])
+
+ states.extend(dic["states"])
+ for ai in dic["action_infos"]:
+ actions.append(ai["actions"])
+ success = success or dic["successful"]
+
+ if len(states) == 0:
+ continue
+
+ # Add only the successful demonstration to dataset
+ if success:
+ print("Demonstration is successful and has been saved")
+ # Delete the last state. This is because when the DataCollector wrapper
+ # recorded the states and actions, the states were recorded AFTER playing that action,
+ # so we end up with an extra state at the end.
+ del states[-1]
+ assert len(states) == len(actions)
+
+ num_eps += 1
+ ep_data_grp = grp.create_group("demo_{}".format(num_eps))
+
+ # store model xml as an attribute
+ xml_path = os.path.join(directory, ep_directory, "model.xml")
+ with open(xml_path, "r") as f:
+ xml_str = f.read()
+ ep_data_grp.attrs["model_file"] = xml_str
+
+ # write datasets for states and actions
+ ep_data_grp.create_dataset("states", data=np.array(states))
+ ep_data_grp.create_dataset("actions", data=np.array(actions))
+ else:
+ print("Demonstration is unsuccessful and has NOT been saved")
+
+ # write dataset attributes (metadata)
+ now = datetime.datetime.now()
+ grp.attrs["date"] = "{}-{}-{}".format(now.month, now.day, now.year)
+ grp.attrs["time"] = "{}:{}:{}".format(now.hour, now.minute, now.second)
+ grp.attrs["repository_version"] = suite.__version__
+ grp.attrs["env"] = env_name
+ grp.attrs["env_info"] = env_info
+
+ f.close()
+
+
+if __name__ == "__main__":
+ # Arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--directory",
+ type=str,
+ default=os.path.join(suite.models.assets_root, "demonstrations"),
+ )
+ parser.add_argument("--environment", type=str, default="Lift")
+ parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
+ parser.add_argument(
+ "--config", type=str, default="single-arm-opposed", help="Specified environment configuration if necessary"
+ )
+ parser.add_argument("--arm", type=str, default="right", help="Which arm to control (eg bimanual) 'right' or 'left'")
+ parser.add_argument("--camera", type=str, default="agentview", help="Which camera to use for collecting demos")
+ parser.add_argument(
+ "--controller", type=str, default="OSC_POSE", help="Choice of controller. Can be 'IK_POSE' or 'OSC_POSE'"
+ )
+ parser.add_argument("--device", type=str, default="keyboard")
+ parser.add_argument("--pos-sensitivity", type=float, default=1.0, help="How much to scale position user inputs")
+ parser.add_argument("--rot-sensitivity", type=float, default=1.0, help="How much to scale rotation user inputs")
+ args = parser.parse_args()
+
+ # Get controller config
+ controller_config = load_controller_config(default_controller=args.controller)
+
+ # Create argument configuration
+ config = {
+ "env_name": args.environment,
+ "robots": args.robots,
+ "controller_configs": controller_config,
+ }
+
+ # Check if we're using a multi-armed environment and use env_configuration argument if so
+ if "TwoArm" in args.environment:
+ config["env_configuration"] = args.config
+
+ # Create environment
+ env = suite.make(
+ **config,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ render_camera=args.camera,
+ ignore_done=True,
+ use_camera_obs=False,
+ reward_shaping=True,
+ control_freq=20,
+ )
+
+ # Wrap this with visualization wrapper
+ env = VisualizationWrapper(env)
+
+ # Grab reference to controller config and convert it to json-encoded string
+ env_info = json.dumps(config)
+
+ # wrap the environment with data collection wrapper
+ tmp_directory = "/tmp/{}".format(str(time.time()).replace(".", "_"))
+ env = DataCollectionWrapper(env, tmp_directory)
+
+ # initialize device
+ if args.device == "keyboard":
+ from robosuite.devices import Keyboard
+
+ device = Keyboard(pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
+ elif args.device == "spacemouse":
+ from robosuite.devices import SpaceMouse
+
+ device = SpaceMouse(pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
+ else:
+ raise Exception("Invalid device choice: choose either 'keyboard' or 'spacemouse'.")
+
+ # make a new timestamped directory
+ t1, t2 = str(time.time()).split(".")
+ new_dir = os.path.join(args.directory, "{}_{}".format(t1, t2))
+ os.makedirs(new_dir)
+
+ # collect demonstrations
+ while True:
+ collect_human_trajectory(env, device, args.arm, args.config)
+ gather_demonstrations_as_hdf5(tmp_directory, new_dir, env_info)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/compile_mjcf_model.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/compile_mjcf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b9334a00d4b3fdfa5a2ab8fa06eb4013faf5cc
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/compile_mjcf_model.py
@@ -0,0 +1,39 @@
+"""Loads a raw mjcf file and saves a compiled mjcf file.
+
+This avoids mujoco-py from complaining about .urdf extension.
+Also allows assets to be compiled properly.
+
+Example:
+ $ python compile_mjcf_model.py source_mjcf.xml target_mjcf.xml
+"""
+
+import os
+import sys
+from shutil import copyfile
+
+import mujoco
+
+
+def print_usage():
+ print("""python compile_mjcf_model.py input_file output_file""")
+
+
+if __name__ == "__main__":
+
+ if len(sys.argv) != 3:
+ print_usage()
+ exit(0)
+
+ input_file = sys.argv[1]
+ output_file = sys.argv[2]
+ input_folder = os.path.dirname(input_file)
+
+ tempfile = os.path.join(input_folder, ".robosuite_temp_model.xml")
+ copyfile(input_file, tempfile)
+
+ model = mujoco.MjModel.from_xml_path(tempfile)
+ xml_string = model.get_xml()
+ with open(output_file, "w") as f:
+ f.write(xml_string)
+
+ os.remove(tempfile)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/make_reset_video.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/make_reset_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8eb7ef70a2721da4ce20b7c6a8f797d6b161eef
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/make_reset_video.py
@@ -0,0 +1,97 @@
+"""
+Convenience script to make a video out of initial environment
+configurations. This can be a useful debugging tool to understand
+what different sampled environment configurations look like.
+"""
+
+import argparse
+
+import imageio
+import numpy as np
+
+import robosuite as suite
+from robosuite.controllers import load_controller_config
+from robosuite.utils.input_utils import *
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # camera to use for generating frames
+ parser.add_argument(
+ "--camera",
+ type=str,
+ default="agentview",
+ )
+
+ # number of frames in output video
+ parser.add_argument(
+ "--frames",
+ type=int,
+ default=10,
+ )
+
+ # path to output video
+ parser.add_argument(
+ "--output",
+ type=str,
+ default="reset.mp4",
+ )
+
+ args = parser.parse_args()
+ camera_name = args.camera
+ num_frames = args.frames
+ output_path = args.output
+
+ # Create dict to hold options that will be passed to env creation call
+ options = {}
+
+ # print welcome info
+ print("Welcome to robosuite v{}!".format(suite.__version__))
+ print(suite.__logo__)
+
+ # Choose environment and add it to options
+ options["env_name"] = choose_environment()
+
+ # If a multi-arm environment has been chosen, choose configuration and appropriate robot(s)
+ if "TwoArm" in options["env_name"]:
+ # Choose env config and add it to options
+ options["env_configuration"] = choose_multi_arm_config()
+
+ # If chosen configuration was bimanual, the corresponding robot must be Baxter. Else, have user choose robots
+ if options["env_configuration"] == "bimanual":
+ options["robots"] = "Baxter"
+ else:
+ options["robots"] = []
+
+ # Have user choose two robots
+ print("A multiple single-arm configuration was chosen.\n")
+
+ for i in range(2):
+ print("Please choose Robot {}...\n".format(i))
+ options["robots"].append(choose_robots(exclude_bimanual=True))
+
+ # Else, we simply choose a single (single-armed) robot to instantiate in the environment
+ else:
+ options["robots"] = choose_robots(exclude_bimanual=True)
+
+ # Load the controller
+ options["controller_configs"] = load_controller_config(default_controller="OSC_POSE")
+
+ # initialize the task
+ env = suite.make(
+ **options,
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ ignore_done=True,
+ use_camera_obs=False,
+ control_freq=20,
+ )
+
+ # write a video
+ video_writer = imageio.get_writer(output_path, fps=5)
+ for i in range(num_frames):
+ env.reset()
+ video_img = env.sim.render(height=512, width=512, camera_name=camera_name)[::-1]
+ env.step(np.zeros_like(env.action_spec[0]))
+ video_writer.append_data(video_img)
+ video_writer.close()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/playback_demonstrations_from_hdf5.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/playback_demonstrations_from_hdf5.py
new file mode 100644
index 0000000000000000000000000000000000000000..0decbd1b6edd31c3d31de4f52b22327d6fd69a64
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/playback_demonstrations_from_hdf5.py
@@ -0,0 +1,106 @@
+"""
+A convenience script to playback random demonstrations from
+a set of demonstrations stored in a hdf5 file.
+
+Arguments:
+ --folder (str): Path to demonstrations
+ --use-actions (optional): If this flag is provided, the actions are played back
+ through the MuJoCo simulator, instead of loading the simulator states
+ one by one.
+ --visualize-gripper (optional): If set, will visualize the gripper site
+
+Example:
+ $ python playback_demonstrations_from_hdf5.py --folder ../models/assets/demonstrations/lift/
+"""
+
+import argparse
+import json
+import os
+import random
+
+import h5py
+import numpy as np
+
+import robosuite
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--folder",
+ type=str,
+ help="Path to your demonstration folder that contains the demo.hdf5 file, e.g.: "
+ "'path_to_assets_dir/demonstrations/YOUR_DEMONSTRATION'",
+ ),
+ parser.add_argument(
+ "--use-actions",
+ action="store_true",
+ )
+ args = parser.parse_args()
+
+ demo_path = args.folder
+ hdf5_path = os.path.join(demo_path, "demo.hdf5")
+ f = h5py.File(hdf5_path, "r")
+ env_name = f["data"].attrs["env"]
+ env_info = json.loads(f["data"].attrs["env_info"])
+
+ env = robosuite.make(
+ **env_info,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ reward_shaping=True,
+ control_freq=20,
+ )
+
+ # list of all demonstrations episodes
+ demos = list(f["data"].keys())
+
+ while True:
+ print("Playing back random episode... (press ESC to quit)")
+
+ # select an episode randomly
+ ep = random.choice(demos)
+
+ # read the model xml, using the metadata stored in the attribute for this episode
+ model_xml = f["data/{}".format(ep)].attrs["model_file"]
+
+ env.reset()
+ xml = env.edit_model_xml(model_xml)
+ env.reset_from_xml_string(xml)
+ env.sim.reset()
+ env.viewer.set_camera(0)
+
+ # load the flattened mujoco states
+ states = f["data/{}/states".format(ep)][()]
+
+ if args.use_actions:
+
+ # load the initial state
+ env.sim.set_state_from_flattened(states[0])
+ env.sim.forward()
+
+ # load the actions and play them back open-loop
+ actions = np.array(f["data/{}/actions".format(ep)][()])
+ num_actions = actions.shape[0]
+
+ for j, action in enumerate(actions):
+ env.step(action)
+ env.render()
+
+ if j < num_actions - 1:
+ # ensure that the actions deterministically lead to the same recorded states
+ state_playback = env.sim.get_state().flatten()
+ if not np.all(np.equal(states[j + 1], state_playback)):
+ err = np.linalg.norm(states[j + 1] - state_playback)
+ print(f"[warning] playback diverged by {err:.2f} for ep {ep} at step {j}")
+
+ else:
+
+ # force the sequence of internal mujoco states one by one
+ for state in states:
+ env.sim.set_state_from_flattened(state)
+ env.sim.forward()
+ env.render()
+
+ f.close()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/setup_macros.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/setup_macros.py
new file mode 100644
index 0000000000000000000000000000000000000000..16abdde5ad8246d018e387890e2c32539602e1b3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/setup_macros.py
@@ -0,0 +1,31 @@
+"""
+This script sets up a private macros file.
+The private macros file (macros_private.py) is not tracked by git,
+allowing user-specific settings that are not tracked by git.
+This script checks if macros_private.py exists.
+If applicable, it creates the private macros at robosuite/macros_private.py
+"""
+
+import os
+import shutil
+
+import robosuite
+
+if __name__ == "__main__":
+ base_path = robosuite.__path__[0]
+ macros_path = os.path.join(base_path, "macros.py")
+ macros_private_path = os.path.join(base_path, "macros_private.py")
+
+ if not os.path.exists(macros_path):
+ print("{} does not exist! Aborting...".format(macros_path))
+
+ if os.path.exists(macros_private_path):
+ ans = input("{} already exists! \noverwrite? (y/n)\n".format(macros_private_path))
+
+ if ans == "y":
+ print("REMOVING")
+ else:
+ exit()
+
+ shutil.copyfile(macros_path, macros_private_path)
+ print("copied {}\nto {}".format(macros_path, macros_private_path))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/tune_camera.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/tune_camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27e380b743fd158298157d7f9046a419c36f93f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/tune_camera.py
@@ -0,0 +1,226 @@
+"""
+Convenience script to tune a camera view in a mujoco environment.
+Allows keyboard presses to move a camera around in the viewer, and
+then prints the final position and quaternion you should set
+for your camera in the mujoco XML file.
+"""
+
+import argparse
+import time
+import xml.etree.ElementTree as ET
+
+import numpy as np
+from pynput.keyboard import Controller, Key, Listener
+
+import robosuite
+import robosuite.utils.transform_utils as T
+from robosuite.utils.camera_utils import CameraMover
+from robosuite.utils.mjcf_utils import find_elements, find_parent
+
+# some settings
+DELTA_POS_KEY_PRESS = 0.05 # delta camera position per key press
+DELTA_ROT_KEY_PRESS = 1 # delta camera angle per key press
+
+
+class KeyboardHandler:
+ def __init__(self, camera_mover):
+ """
+ Store internal state here.
+
+ Args:
+ camera_mover (CameraMover): Playback camera class
+ cam_body_id (int): id corresponding to parent body of camera element
+ """
+ self.camera_mover = camera_mover
+
+ # make a thread to listen to keyboard and register our callback functions
+ self.listener = Listener(on_press=self.on_press, on_release=self.on_release)
+
+ # start listening
+ self.listener.start()
+
+ def on_press(self, key):
+ """
+ Key handler for key presses.
+
+ Args:
+ key (int): keycode corresponding to the key that was pressed
+ """
+
+ try:
+ # controls for moving rotation
+ if key == Key.up:
+ # rotate up
+ self.camera_mover.rotate_camera(point=None, axis=[1.0, 0.0, 0.0], angle=DELTA_ROT_KEY_PRESS)
+ elif key == Key.down:
+ # rotate down
+ self.camera_mover.rotate_camera(point=None, axis=[-1.0, 0.0, 0.0], angle=DELTA_ROT_KEY_PRESS)
+ elif key == Key.left:
+ # rotate left
+ self.camera_mover.rotate_camera(point=None, axis=[0.0, 1.0, 0.0], angle=DELTA_ROT_KEY_PRESS)
+ elif key == Key.right:
+ # rotate right
+ self.camera_mover.rotate_camera(point=None, axis=[0.0, -1.0, 0.0], angle=DELTA_ROT_KEY_PRESS)
+
+ # controls for moving position
+ elif key.char == "w":
+ # move forward
+ self.camera_mover.move_camera(direction=[0.0, 0.0, -1.0], scale=DELTA_POS_KEY_PRESS)
+ elif key.char == "s":
+ # move backward
+ self.camera_mover.move_camera(direction=[0.0, 0.0, 1.0], scale=DELTA_POS_KEY_PRESS)
+ elif key.char == "a":
+ # move left
+ self.camera_mover.move_camera(direction=[-1.0, 0.0, 0.0], scale=DELTA_POS_KEY_PRESS)
+ elif key.char == "d":
+ # move right
+ self.camera_mover.move_camera(direction=[1.0, 0.0, 0.0], scale=DELTA_POS_KEY_PRESS)
+ elif key.char == "r":
+ # move up
+ self.camera_mover.move_camera(direction=[0.0, 1.0, 0.0], scale=DELTA_POS_KEY_PRESS)
+ elif key.char == "f":
+ # move down
+ self.camera_mover.move_camera(direction=[0.0, -1.0, 0.0], scale=DELTA_POS_KEY_PRESS)
+ elif key.char == ".":
+ # rotate counterclockwise
+ self.camera_mover.rotate_camera(point=None, axis=[0.0, 0.0, 1.0], angle=DELTA_ROT_KEY_PRESS)
+ elif key.char == "/":
+ # rotate clockwise
+ self.camera_mover.rotate_camera(point=None, axis=[0.0, 0.0, -1.0], angle=DELTA_ROT_KEY_PRESS)
+
+ except AttributeError as e:
+ pass
+
+ def on_release(self, key):
+ """
+ Key handler for key releases.
+
+ Args:
+ key: [NOT USED]
+ """
+ pass
+
+
+def print_command(char, info):
+ """
+ Prints out the command + relevant info entered by user
+
+ Args:
+ char (str): Command entered
+ info (str): Any additional info to print
+ """
+ char += " " * (10 - len(char))
+ print("{}\t{}".format(char, info))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env", type=str, default="Lift")
+ parser.add_argument("--robots", nargs="+", type=str, default="Sawyer", help="Which robot(s) to use in the env")
+ args = parser.parse_args()
+
+ print("\nWelcome to the camera tuning script! You will be able to tune a camera view")
+ print("by moving it around using your keyboard. The controls are printed below.")
+
+ print("")
+ print_command("Keys", "Command")
+ print_command("w-s", "zoom the camera in/out")
+ print_command("a-d", "pan the camera left/right")
+ print_command("r-f", "pan the camera up/down")
+ print_command("arrow keys", "rotate the camera to change view direction")
+ print_command(".-/", "rotate the camera view without changing view direction")
+ print("")
+
+ # read camera XML tag from user input
+ inp = input(
+ "\nPlease paste a camera name below \n"
+ "OR xml tag below (e.g. ) \n"
+ "OR leave blank for an example:\n"
+ )
+
+ if len(inp) == 0:
+ if args.env != "Lift":
+ raise Exception("ERROR: env must be Lift to run default example.")
+ print("\nUsing an example tag corresponding to the frontview camera.")
+ print("This xml tag was copied from robosuite/models/assets/arenas/table_arena.xml")
+ inp = ''
+
+ # remember the tag and infer some properties
+ from_tag = "<" in inp
+ notify_str = (
+ "NOTE: using the following xml tag:\n"
+ if from_tag
+ else "NOTE: using the following camera (initialized at default sim location)\n"
+ )
+
+ print(notify_str)
+ print("{}\n".format(inp))
+
+ cam_tree = ET.fromstring(inp) if from_tag else ET.Element("camera", attrib={"name": inp})
+ CAMERA_NAME = cam_tree.get("name")
+
+ # make the environment
+ env = robosuite.make(
+ args.env,
+ robots=args.robots,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ control_freq=100,
+ )
+ env.reset()
+
+ # Create the camera mover
+ camera_mover = CameraMover(
+ env=env,
+ camera=CAMERA_NAME,
+ )
+
+ # Make sure we're using the camera that we're modifying
+ camera_id = env.sim.model.camera_name2id(CAMERA_NAME)
+ env.viewer.set_camera(camera_id=camera_id)
+
+ # Infer initial camera pose
+ if from_tag:
+ initial_file_camera_pos = np.array(cam_tree.get("pos").split(" ")).astype(float)
+ initial_file_camera_quat = T.convert_quat(np.array(cam_tree.get("quat").split(" ")).astype(float), to="xyzw")
+ # Set these values as well
+ camera_mover.set_camera_pose(pos=initial_file_camera_pos, quat=initial_file_camera_quat)
+ # Optionally set fov if specified
+ cam_fov = cam_tree.get("fovy", None)
+ if cam_fov is not None:
+ env.sim.model.cam_fovy[camera_id] = float(cam_fov)
+ else:
+ initial_file_camera_pos, initial_file_camera_quat = camera_mover.get_camera_pose()
+ # Define initial file camera pose
+ initial_file_camera_pose = T.make_pose(initial_file_camera_pos, T.quat2mat(initial_file_camera_quat))
+
+ # remember difference between camera pose in initial tag and absolute camera pose in world
+ initial_world_camera_pos, initial_world_camera_quat = camera_mover.get_camera_pose()
+ initial_world_camera_pose = T.make_pose(initial_world_camera_pos, T.quat2mat(initial_world_camera_quat))
+ world_in_file = initial_file_camera_pose.dot(T.pose_inv(initial_world_camera_pose))
+
+ # register callbacks to handle key presses in the viewer
+ key_handler = KeyboardHandler(camera_mover=camera_mover)
+
+ # just spin to let user interact with window
+ spin_count = 0
+ while True:
+ action = np.zeros(env.action_dim)
+ obs, reward, done, _ = env.step(action)
+ env.render()
+ spin_count += 1
+ if spin_count % 500 == 0:
+ # convert from world coordinates to file coordinates (xml subtree)
+ camera_pos, camera_quat = camera_mover.get_camera_pose()
+ world_camera_pose = T.make_pose(camera_pos, T.quat2mat(camera_quat))
+ file_camera_pose = world_in_file.dot(world_camera_pose)
+ # TODO: Figure out why numba causes black screen of death (specifically, during mat2pose --> mat2quat call below)
+ camera_pos, camera_quat = T.mat2pose(file_camera_pose)
+ camera_quat = T.convert_quat(camera_quat, to="wxyz")
+
+ print("\n\ncurrent camera tag you should copy")
+ cam_tree.set("pos", "{} {} {}".format(camera_pos[0], camera_pos[1], camera_pos[2]))
+ cam_tree.set("quat", "{} {} {} {}".format(camera_quat[0], camera_quat[1], camera_quat[2], camera_quat[3]))
+ print(ET.tostring(cam_tree, encoding="utf8").decode("utf8"))
diff --git a/phantom/submodules/phantom-robosuite/robosuite/scripts/tune_joints.py b/phantom/submodules/phantom-robosuite/robosuite/scripts/tune_joints.py
new file mode 100644
index 0000000000000000000000000000000000000000..09dedbd5fbe308fa523c8afeb2479347972504d3
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/scripts/tune_joints.py
@@ -0,0 +1,311 @@
+"""
+Convenience script to tune a robot's joint positions in a mujoco environment.
+Allows keyboard presses to move specific robot joints around in the viewer, and
+then prints the current joint parameters upon an inputted command
+
+RELEVANT KEY PRESSES:
+ '1 - n' : Sets the active robot joint being tuned to this number. Maximum
+ is n which is the number of robot joints
+ 't' : Toggle between robot arms being tuned (only applicable for multi-arm environments)
+ 'r' : Resets the active joint values to 0
+ 'UP_ARROW' : Increment the active robot joint position
+ 'DOWN_ARROW' : Decrement the active robot joint position
+ 'RIGHT_ARROW' : Increment the delta joint position change per keypress
+ 'LEFT_ARROW' : Decrement the delta joint position change per keypress
+
+"""
+
+import argparse
+
+import numpy as np
+from pynput.keyboard import Controller, Key, Listener
+
+import robosuite
+from robosuite.robots import SingleArm
+
+
+class KeyboardHandler:
+ def __init__(self, env, delta=0.05):
+ """
+ Store internal state here.
+
+ Args:
+ env (MujocoEnv): Environment to use
+ delta (float): initial joint tuning increment
+ """
+ self.env = env
+ self.delta = delta
+ self.num_robots = len(env.robots)
+ self.active_robot_num = 0
+ self.active_arm_joint = 1
+ self.active_arm = "right" # only relevant for bimanual robots
+ self.current_joints_pos = env.sim.data.qpos[self.active_robot._ref_joint_pos_indexes[: self.num_joints]]
+
+ # make a thread to listen to keyboard and register our callback functions
+ self.listener = Listener(on_press=self.on_press, on_release=self.on_release)
+
+ # start listening
+ self.listener.start()
+
+ def on_press(self, key):
+ """
+ Key handler for key presses.
+
+ Args:
+ key (int): keycode corresponding to the key that was pressed
+ """
+
+ try:
+ if key == Key.up:
+ # Increment the active joint
+ self._update_joint_position(self.active_arm_joint, self.delta)
+ elif key == Key.down:
+ # Decrement the active joint
+ self._update_joint_position(self.active_arm_joint, -self.delta)
+ elif key == Key.right:
+ # Increment the delta value
+ self.delta = min(1.0, self.delta + 0.005)
+ # Print out new value to user
+ print("Delta now = {:.3f}".format(self.delta))
+ elif key == Key.left:
+ # Decrement the delta value
+ self.delta = max(0, self.delta - 0.005)
+ print("Delta now = {:.3f}".format(self.delta))
+ # controls for setting active arm
+ elif key.char == "0":
+ # Notify use that joint indexes are 1-indexed
+ print("Joint Indexes are 1-Indexed. Available joints are 1 - {}".format(self.num_joints))
+ elif key.char == "1":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(1):
+ self.active_arm_joint = 1
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "2":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(2):
+ self.active_arm_joint = 2
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "3":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(3):
+ self.active_arm_joint = 3
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "4":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(4):
+ self.active_arm_joint = 4
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "5":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(5):
+ self.active_arm_joint = 5
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "6":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(6):
+ self.active_arm_joint = 6
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "7":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(7):
+ self.active_arm_joint = 7
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "8":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(8):
+ self.active_arm_joint = 8
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "9":
+ # Make sure range is valid; if so, update this specific joint
+ if self._check_valid_joint(9):
+ self.active_arm_joint = 9
+ # Print out to user
+ print("New joint being tuned: {}".format(self.active_arm_joint))
+ elif key.char == "t":
+ # Toggle active arm
+ self._toggle_arm()
+ elif key.char == "r":
+ # Reset active arm joint qpos to 0
+ self.set_joint_positions(np.zeros(self.num_joints))
+
+ except AttributeError as e:
+ pass
+
+ def on_release(self, key):
+ """
+ Key handler for key releases.
+
+ Args:
+ key: [NOT USED]
+ """
+ pass
+
+ def set_joint_positions(self, qpos):
+ """
+ Automatically sets the joint positions to be the given value
+
+ Args:
+ qpos (np.array): Joint positions to set
+ """
+ self.current_joints_pos = qpos
+ self._update_joint_position(1, 0)
+
+ def _check_valid_joint(self, i):
+ """
+ Checks to make sure joint number request @i is within valid range
+
+ Args:
+ i (int): Index to validate
+
+ Returns:
+ bool: True if index @i is valid, else prints out an error and returns False
+ """
+ if i > self.num_joints:
+ # Print error
+ print("Error: Requested joint {} is out of range; available joints are 1 - {}".format(i, self.num_joints))
+ return False
+ else:
+ return True
+
+ def _toggle_arm(self):
+ """
+ Toggle between arms in the environment to set as current active arm
+ """
+ if isinstance(self.active_robot, SingleArm):
+ self.active_robot_num = (self.active_robot_num + 1) % self.num_robots
+ robot = self.active_robot_num
+ else: # Bimanual case
+ self.active_arm = "left" if self.active_arm == "right" else "right"
+ robot = self.active_arm
+ # Reset joint being controlled to 1
+ self.active_arm_joint = 1
+ # Print out new robot to user
+ print("New robot arm being tuned: {}".format(robot))
+
+ def _update_joint_position(self, i, delta):
+ """
+ Updates specified joint position @i by value @delta from its current position
+ Note: assumes @i is already within the valid joint range
+
+ Args:
+ i (int): Joint index to update
+ delta (float): Increment to alter specific joint by
+ """
+ self.current_joints_pos[i - 1] += delta
+ if isinstance(self.active_robot, SingleArm):
+ robot = self.active_robot_num
+ self.env.sim.data.qpos[self.active_robot._ref_joint_pos_indexes] = self.current_joints_pos
+ else: # Bimanual case
+ robot = self.active_arm
+ if self.active_arm == "right":
+ self.env.sim.data.qpos[
+ self.active_robot._ref_joint_pos_indexes[: self.num_joints]
+ ] = self.current_joints_pos
+ else: # left arm case
+ self.env.sim.data.qpos[
+ self.active_robot._ref_joint_pos_indexes[self.num_joints :]
+ ] = self.current_joints_pos
+ # Print out current joint positions to user
+ print("Robot {} joint qpos: {}".format(robot, self.current_joints_pos))
+
+ @property
+ def active_robot(self):
+ """
+ Returns:
+ Robot: active robot arm currently being tuned
+ """
+ return self.env.robots[self.active_robot_num]
+
+ @property
+ def num_joints(self):
+ """
+ Returns:
+ int: number of joints for the current arm
+ """
+ if isinstance(self.active_robot, SingleArm):
+ return len(self.active_robot.torque_limits[0])
+ else: # Bimanual arm case
+ return int(len(self.active_robot.torque_limits[0]) / 2)
+
+
+def print_command(char, info):
+ """
+ Prints out the command + relevant info entered by user
+
+ Args:
+ char (str): Command entered
+ info (str): Any additional info to print
+ """
+ char += " " * (10 - len(char))
+ print("{}\t{}".format(char, info))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--env", type=str, default="Lift")
+ parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
+ parser.add_argument(
+ "--init_qpos", nargs="+", type=float, default=0, help="Initial qpos to use. 0 defaults to all zeros"
+ )
+
+ args = parser.parse_args()
+
+ print(
+ "\nWelcome to the joint tuning script! You will be able to tune the robot\n"
+ "arm joints in the specified environment by using your keyboard. The \n"
+ "controls are printed below:"
+ )
+
+ print("")
+ print_command("Keys", "Command")
+ print_command("1-N", "Active Joint being tuned (N=number of joints for the active arm)")
+ print_command("t", "Toggle between robot arms in the environment")
+ print_command("r", "Reset active arm joints to all 0s")
+ print_command("up/down", "incr/decrement the active joint angle")
+ print_command("right/left", "incr/decrement the delta joint angle per up/down keypress")
+ print("")
+
+ # Setup printing options for numbers
+ np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
+
+ # Define the controller
+ controller_config = robosuite.load_controller_config(default_controller="JOINT_POSITION")
+
+ # make the environment
+ env = robosuite.make(
+ args.env,
+ robots=args.robots,
+ has_renderer=True,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ control_freq=20,
+ render_camera=None,
+ controller_configs=controller_config,
+ initialization_noise=None,
+ )
+ env.reset()
+
+ # register callbacks to handle key presses in the viewer
+ key_handler = KeyboardHandler(env=env)
+
+ # Set initial state
+ if type(args.init_qpos) == int and args.init_qpos == 0:
+ # Default to all zeros
+ pass
+ else:
+ key_handler.set_joint_positions(args.init_qpos)
+
+ # just spin to let user interact with window
+ while True:
+ action = np.zeros(env.action_dim)
+ obs, reward, done, _ = env.step(action)
+ env.render()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..70f80caa075d3150eb8346a4aae00fdf6438f499
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/__init__.py
@@ -0,0 +1,3 @@
+from .errors import robosuiteError, XMLError, SimulationError, RandomizationError
+
+from .opencv_renderer import OpenCVRenderer
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/binding_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/binding_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd46540e1a316603946cbf118d912d92f0ca604c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/binding_utils.py
@@ -0,0 +1,1177 @@
+"""
+Useful classes for supporting DeepMind MuJoCo binding.
+"""
+
+import gc
+import os
+from tempfile import TemporaryDirectory
+
+# DIRTY HACK copied from mujoco-py - a global lock on rendering
+from threading import Lock
+
+import mujoco
+import numpy as np
+
+_MjSim_render_lock = Lock()
+
+import ctypes
+import ctypes.util
+import os
+import platform
+import subprocess
+
+import robosuite.macros as macros
+
+_SYSTEM = platform.system()
+if _SYSTEM == "Windows":
+ ctypes.WinDLL(os.path.join(os.path.dirname(__file__), "mujoco.dll"))
+
+CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+if CUDA_VISIBLE_DEVICES != "":
+ MUJOCO_EGL_DEVICE_ID = os.environ.get("MUJOCO_EGL_DEVICE_ID", None)
+ if MUJOCO_EGL_DEVICE_ID is not None:
+ assert MUJOCO_EGL_DEVICE_ID.isdigit() and (
+ MUJOCO_EGL_DEVICE_ID in CUDA_VISIBLE_DEVICES
+ ), "MUJOCO_EGL_DEVICE_ID needs to be set to one of the device id specified in CUDA_VISIBLE_DEVICES"
+
+if macros.MUJOCO_GPU_RENDERING and os.environ.get("MUJOCO_GL", None) not in ["osmesa", "glx"]:
+ # If gpu rendering is specified in macros, then we enforce gpu
+ # option for rendering
+ if _SYSTEM == "Darwin":
+ os.environ["MUJOCO_GL"] = "cgl"
+ else:
+ os.environ["MUJOCO_GL"] = "egl"
+_MUJOCO_GL = os.environ.get("MUJOCO_GL", "").lower().strip()
+if _MUJOCO_GL not in ("disable", "disabled", "off", "false", "0"):
+ _VALID_MUJOCO_GL = ("enable", "enabled", "on", "true", "1", "glfw", "")
+ if _SYSTEM == "Linux":
+ _VALID_MUJOCO_GL += ("glx", "egl", "osmesa")
+ elif _SYSTEM == "Windows":
+ _VALID_MUJOCO_GL += ("wgl",)
+ elif _SYSTEM == "Darwin":
+ _VALID_MUJOCO_GL += ("cgl",)
+ if _MUJOCO_GL not in _VALID_MUJOCO_GL:
+ raise RuntimeError(f"invalid value for environment variable MUJOCO_GL: {_MUJOCO_GL}")
+ if _SYSTEM == "Linux" and _MUJOCO_GL == "osmesa":
+ from robosuite.renderers.context.osmesa_context import OSMesaGLContext as GLContext
+ elif _SYSTEM == "Linux" and _MUJOCO_GL == "egl":
+ from robosuite.renderers.context.egl_context import EGLGLContext as GLContext
+ else:
+ from robosuite.renderers.context.glfw_context import GLFWGLContext as GLContext
+
+
+class MjRenderContext:
+ """
+ Class that encapsulates rendering functionality for a
+ MuJoCo simulation.
+
+ See https://github.com/openai/mujoco-py/blob/4830435a169c1f3e3b5f9b58a7c3d9c39bdf4acb/mujoco_py/mjrendercontext.pyx
+ """
+
+ def __init__(self, sim, offscreen=True, device_id=-1, max_width=640, max_height=480):
+ assert offscreen, "only offscreen supported for now"
+ self.sim = sim
+ self.offscreen = offscreen
+ self.device_id = device_id
+
+ # setup GL context with defaults for now
+ self.gl_ctx = GLContext(max_width=max_width, max_height=max_height, device_id=self.device_id)
+ self.gl_ctx.make_current()
+
+ # Ensure the model data has been updated so that there
+ # is something to render
+ sim.forward()
+ # make sure sim has this context
+ sim.add_render_context(self)
+
+ self.model = sim.model
+ self.data = sim.data
+
+ # create default scene
+ self.scn = mujoco.MjvScene(sim.model._model, maxgeom=1000)
+
+ # camera
+ self.cam = mujoco.MjvCamera()
+ self.cam.fixedcamid = 0
+ self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
+
+ # options for visual / collision mesh can be set externally, e.g. vopt.geomgroup[0], vopt.geomgroup[1]
+ self.vopt = mujoco.MjvOption()
+
+ self.pert = mujoco.MjvPerturb()
+ self.pert.active = 0
+ self.pert.select = 0
+ self.pert.skinselect = -1
+
+ # self._markers = []
+ # self._overlay = {}
+
+ self._set_mujoco_context_and_buffers()
+
+ def _set_mujoco_context_and_buffers(self):
+ self.con = mujoco.MjrContext(self.model._model, mujoco.mjtFontScale.mjFONTSCALE_150)
+ mujoco.mjr_setBuffer(mujoco.mjtFramebuffer.mjFB_OFFSCREEN, self.con)
+
+ def update_offscreen_size(self, width, height):
+ if (width != self.con.offWidth) or (height != self.con.offHeight):
+ self.model.vis.global_.offwidth = width
+ self.model.vis.global_.offheight = height
+ self.con.free()
+ del self.con
+ self._set_mujoco_context_and_buffers()
+
+ def upload_texture(self, tex_id):
+ """Uploads given texture to the GPU"""
+ self.gl_ctx.make_current()
+ mujoco.mjr_uploadTexture(self.model, self.con, tex_id)
+
+ def render(self, width, height, camera_id=None, segmentation=False):
+ viewport = mujoco.MjrRect(0, 0, width, height)
+
+ # if self.sim.render_callback is not None:
+ # self.sim.render_callback(self.sim, self)
+
+ # update width and height of rendering context if necessary
+ if width > self.con.offWidth or height > self.con.offHeight:
+ new_width = max(width, self.model.vis.global_.offwidth)
+ new_height = max(height, self.model.vis.global_.offheight)
+ self.update_offscreen_size(new_width, new_height)
+
+ if camera_id is not None:
+ if camera_id == -1:
+ self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
+ else:
+ self.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
+ self.cam.fixedcamid = camera_id
+
+ mujoco.mjv_updateScene(
+ self.model._model, self.data._data, self.vopt, self.pert, self.cam, mujoco.mjtCatBit.mjCAT_ALL, self.scn
+ )
+
+ if segmentation:
+ self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 1
+ self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 1
+
+ # for marker_params in self._markers:
+ # self._add_marker_to_scene(marker_params)
+
+ mujoco.mjr_render(viewport=viewport, scn=self.scn, con=self.con)
+ # for gridpos, (text1, text2) in self._overlay.items():
+ # mjr_overlay(const.FONTSCALE_150, gridpos, rect, text1.encode(), text2.encode(), &self._con)
+
+ if segmentation:
+ self.scn.flags[mujoco.mjtRndFlag.mjRND_SEGMENT] = 0
+ self.scn.flags[mujoco.mjtRndFlag.mjRND_IDCOLOR] = 0
+
+ def read_pixels(self, width, height, depth=False, segmentation=False):
+ viewport = mujoco.MjrRect(0, 0, width, height)
+ rgb_img = np.empty((height, width, 3), dtype=np.uint8)
+ depth_img = np.empty((height, width), dtype=np.float32) if depth else None
+
+ mujoco.mjr_readPixels(rgb=rgb_img, depth=depth_img, viewport=viewport, con=self.con)
+
+ ret_img = rgb_img
+ if segmentation:
+ seg_img = rgb_img[:, :, 0] + rgb_img[:, :, 1] * (2**8) + rgb_img[:, :, 2] * (2**16)
+ seg_img[seg_img >= (self.scn.ngeom + 1)] = 0
+ seg_ids = np.full((self.scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)
+
+ for i in range(self.scn.ngeom):
+ geom = self.scn.geoms[i]
+ if geom.segid != -1:
+ seg_ids[geom.segid + 1, 0] = geom.objtype
+ seg_ids[geom.segid + 1, 1] = geom.objid
+ ret_img = seg_ids[seg_img]
+
+ if depth:
+ return (ret_img, depth_img)
+ else:
+ return ret_img
+
+ def upload_texture(self, tex_id):
+ """Uploads given texture to the GPU."""
+ self.gl_ctx.make_current()
+ mujoco.mjr_uploadTexture(self.model, self.con, tex_id)
+
+ def __del__(self):
+ # free mujoco rendering context and GL rendering context
+ self.con.free()
+ self.gl_ctx.free()
+ del self.con
+ del self.gl_ctx
+ del self.scn
+ del self.cam
+ del self.vopt
+ del self.pert
+
+
+class MjRenderContextOffscreen(MjRenderContext):
+ def __init__(self, sim, device_id, max_width=640, max_height=480):
+ super().__init__(sim, offscreen=True, device_id=device_id, max_width=max_width, max_height=max_height)
+
+
+class MjSimState:
+ """
+ A mujoco simulation state.
+ """
+
+ def __init__(self, time, qpos, qvel):
+ self.time = time
+ self.qpos = qpos
+ self.qvel = qvel
+
+ @classmethod
+ def from_flattened(cls, array, sim):
+ """
+ Takes flat mjstate array and MjSim instance and
+ returns MjSimState.
+ """
+ idx_time = 0
+ idx_qpos = idx_time + 1
+ idx_qvel = idx_qpos + sim.model.nq
+
+ time = array[idx_time]
+ qpos = array[idx_qpos : idx_qpos + sim.model.nq]
+ qvel = array[idx_qvel : idx_qvel + sim.model.nv]
+ assert sim.model.na == 0
+
+ return cls(time=time, qpos=qpos, qvel=qvel)
+
+ def flatten(self):
+ return np.concatenate([[self.time], self.qpos, self.qvel], axis=0)
+
+
+class _MjModelMeta(type):
+ """
+ Metaclass which allows MjModel below to delegate to mujoco.MjModel.
+
+ Taken from dm_control: https://github.com/deepmind/dm_control/blob/main/dm_control/mujoco/wrapper/core.py#L244
+ """
+
+ def __new__(cls, name, bases, dct):
+ for attr in dir(mujoco.MjModel):
+ if not attr.startswith("_"):
+ if attr not in dct:
+ # pylint: disable=protected-access
+ fget = lambda self, attr=attr: getattr(self._model, attr)
+ fset = lambda self, value, attr=attr: setattr(self._model, attr, value)
+ # pylint: enable=protected-access
+ dct[attr] = property(fget, fset)
+ return super().__new__(cls, name, bases, dct)
+
+
+class MjModel(metaclass=_MjModelMeta):
+ """Wrapper class for a MuJoCo 'mjModel' instance.
+ MjModel encapsulates features of the model that are expected to remain
+ constant. It also contains simulation and visualization options which may be
+ changed occasionally, although this is done explicitly by the user.
+ """
+
+ _HAS_DYNAMIC_ATTRIBUTES = True
+
+ def __init__(self, model_ptr):
+ """Creates a new MjModel instance from a mujoco.MjModel."""
+ self._model = model_ptr
+
+ # make useful mappings such as _body_name2id and _body_id2name
+ self.make_mappings()
+
+ @classmethod
+ def from_xml_path(cls, xml_path):
+ """Creates an MjModel instance from a path to a model XML file."""
+ model_ptr = _get_model_ptr_from_xml(xml_path=xml_path)
+ return cls(model_ptr)
+
+ def __del__(self):
+ # free mujoco model
+ del self._model
+
+ """
+ Some methods supported by sim.model in mujoco-py.
+ Copied from https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L2611
+ """
+
+ def _extract_mj_names(self, name_adr, num_obj, obj_type):
+ """
+ See https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L1127
+ """
+
+ ### TODO: fix this to use @name_adr like mujoco-py - more robust than assuming IDs are continuous ###
+
+ # objects don't need to be named in the XML, so name might be None
+ id2name = {i: None for i in range(num_obj)}
+ name2id = {}
+ for i in range(num_obj):
+ name = mujoco.mj_id2name(self._model, obj_type, i)
+ name2id[name] = i
+ id2name[i] = name
+
+ # # objects don't need to be named in the XML, so name might be None
+ # id2name = { i: None for i in range(num_obj) }
+ # name2id = {}
+ # for i in range(num_obj):
+ # name = self.model.names[name_adr[i]]
+ # decoded_name = name.decode()
+ # if decoded_name:
+ # obj_id = mujoco.mj_name2id(self.model, obj_type, name)
+ # assert (0 <= obj_id < num_obj) and (id2name[obj_id] is None)
+ # name2id[decoded_name] = obj_id
+ # id2name[obj_id] = decoded_name
+
+ # sort names by increasing id to keep order deterministic
+ return tuple(id2name[nid] for nid in sorted(name2id.values())), name2id, id2name
+
+ def make_mappings(self):
+ """
+ Make some useful internal mappings that mujoco-py supported.
+ """
+ p = self
+ self.body_names, self._body_name2id, self._body_id2name = self._extract_mj_names(
+ p.name_bodyadr, p.nbody, mujoco.mjtObj.mjOBJ_BODY
+ )
+ self.joint_names, self._joint_name2id, self._joint_id2name = self._extract_mj_names(
+ p.name_jntadr, p.njnt, mujoco.mjtObj.mjOBJ_JOINT
+ )
+ self.geom_names, self._geom_name2id, self._geom_id2name = self._extract_mj_names(
+ p.name_geomadr, p.ngeom, mujoco.mjtObj.mjOBJ_GEOM
+ )
+ self.site_names, self._site_name2id, self._site_id2name = self._extract_mj_names(
+ p.name_siteadr, p.nsite, mujoco.mjtObj.mjOBJ_SITE
+ )
+ self.light_names, self._light_name2id, self._light_id2name = self._extract_mj_names(
+ p.name_lightadr, p.nlight, mujoco.mjtObj.mjOBJ_LIGHT
+ )
+ self.camera_names, self._camera_name2id, self._camera_id2name = self._extract_mj_names(
+ p.name_camadr, p.ncam, mujoco.mjtObj.mjOBJ_CAMERA
+ )
+ self.actuator_names, self._actuator_name2id, self._actuator_id2name = self._extract_mj_names(
+ p.name_actuatoradr, p.nu, mujoco.mjtObj.mjOBJ_ACTUATOR
+ )
+ self.sensor_names, self._sensor_name2id, self._sensor_id2name = self._extract_mj_names(
+ p.name_sensoradr, p.nsensor, mujoco.mjtObj.mjOBJ_SENSOR
+ )
+ self.tendon_names, self._tendon_name2id, self._tendon_id2name = self._extract_mj_names(
+ p.name_tendonadr, p.ntendon, mujoco.mjtObj.mjOBJ_TENDON
+ )
+ self.mesh_names, self._mesh_name2id, self._mesh_id2name = self._extract_mj_names(
+ p.name_meshadr, p.nmesh, mujoco.mjtObj.mjOBJ_MESH
+ )
+
+ def body_id2name(self, id):
+ """Get body name from mujoco body id."""
+ if id not in self._body_id2name:
+ raise ValueError("No body with id %d exists." % id)
+ return self._body_id2name[id]
+
+ def body_name2id(self, name):
+ """Get body id from mujoco body name."""
+ if name not in self._body_name2id:
+ raise ValueError('No "body" with name %s exists. Available "body" names = %s.' % (name, self.body_names))
+ return self._body_name2id[name]
+
+ def joint_id2name(self, id):
+ """Get joint name from mujoco joint id."""
+ if id not in self._joint_id2name:
+ raise ValueError("No joint with id %d exists." % id)
+ return self._joint_id2name[id]
+
+ def joint_name2id(self, name):
+ """Get joint id from joint name."""
+ if name not in self._joint_name2id:
+ raise ValueError('No "joint" with name %s exists. Available "joint" names = %s.' % (name, self.joint_names))
+ return self._joint_name2id[name]
+
+ def geom_id2name(self, id):
+ """Get geom name from geom id."""
+ if id not in self._geom_id2name:
+ raise ValueError("No geom with id %d exists." % id)
+ return self._geom_id2name[id]
+
+ def geom_name2id(self, name):
+ """Get geom id from geom name."""
+ if name not in self._geom_name2id:
+ raise ValueError('No "geom" with name %s exists. Available "geom" names = %s.' % (name, self.geom_names))
+ return self._geom_name2id[name]
+
+ def site_id2name(self, id):
+ """Get site name from site id."""
+ if id not in self._site_id2name:
+ raise ValueError("No site with id %d exists." % id)
+ return self._site_id2name[id]
+
+ def site_name2id(self, name):
+ """Get site id from site name."""
+ if name not in self._site_name2id:
+ raise ValueError('No "site" with name %s exists. Available "site" names = %s.' % (name, self.site_names))
+ return self._site_name2id[name]
+
+ def light_id2name(self, id):
+ """Get light name from light id."""
+ if id not in self._light_id2name:
+ raise ValueError("No light with id %d exists." % id)
+ return self._light_id2name[id]
+
+ def light_name2id(self, name):
+ """Get light id from light name."""
+ if name not in self._light_name2id:
+ raise ValueError('No "light" with name %s exists. Available "light" names = %s.' % (name, self.light_names))
+ return self._light_name2id[name]
+
+ def camera_id2name(self, id):
+ """Get camera name from camera id."""
+ if id not in self._camera_id2name:
+ raise ValueError("No camera with id %d exists." % id)
+ return self._camera_id2name[id]
+
+ def camera_name2id(self, name):
+ """Get camera id from camera name."""
+ if name not in self._camera_name2id:
+ raise ValueError(
+ 'No "camera" with name %s exists. Available "camera" names = %s.' % (name, self.camera_names)
+ )
+ return self._camera_name2id[name]
+
+ def actuator_id2name(self, id):
+ """Get actuator name from actuator id."""
+ if id not in self._actuator_id2name:
+ raise ValueError("No actuator with id %d exists." % id)
+ return self._actuator_id2name[id]
+
+ def actuator_name2id(self, name):
+ """Get actuator id from actuator name."""
+ if name not in self._actuator_name2id:
+ raise ValueError(
+ 'No "actuator" with name %s exists. Available "actuator" names = %s.' % (name, self.actuator_names)
+ )
+ return self._actuator_name2id[name]
+
+ def sensor_id2name(self, id):
+ """Get sensor name from sensor id."""
+ if id not in self._sensor_id2name:
+ raise ValueError("No sensor with id %d exists." % id)
+ return self._sensor_id2name[id]
+
+ def sensor_name2id(self, name):
+ """Get sensor id from sensor name."""
+ if name not in self._sensor_name2id:
+ raise ValueError(
+ 'No "sensor" with name %s exists. Available "sensor" names = %s.' % (name, self.sensor_names)
+ )
+ return self._sensor_name2id[name]
+
+ def tendon_id2name(self, id):
+ """Get tendon name from tendon id."""
+ if id not in self._tendon_id2name:
+ raise ValueError("No tendon with id %d exists." % id)
+ return self._tendon_id2name[id]
+
+ def tendon_name2id(self, name):
+ """Get tendon id from tendon name."""
+ if name not in self._tendon_name2id:
+ raise ValueError(
+ 'No "tendon" with name %s exists. Available "tendon" names = %s.' % (name, self.tendon_names)
+ )
+ return self._tendon_name2id[name]
+
+ def mesh_id2name(self, id):
+ """Get mesh name from mesh id."""
+ if id not in self._mesh_id2name:
+ raise ValueError("No mesh with id %d exists." % id)
+ return self._mesh_id2name[id]
+
+ def mesh_name2id(self, name):
+ """Get mesh id from mesh name."""
+ if name not in self._mesh_name2id:
+ raise ValueError('No "mesh" with name %s exists. Available "mesh" names = %s.' % (name, self.mesh_names))
+ return self._mesh_name2id[name]
+
+ # def userdata_id2name(self, id):
+ # if id not in self._userdata_id2name:
+ # raise ValueError("No userdata with id %d exists." % id)
+ # return self._userdata_id2name[id]
+
+ # def userdata_name2id(self, name):
+ # if name not in self._userdata_name2id:
+ # raise ValueError("No \"userdata\" with name %s exists. Available \"userdata\" names = %s." % (name, self.userdata_names))
+ # return self._userdata_name2id[name]
+
+ def get_xml(self):
+ with TemporaryDirectory() as td:
+ filename = os.path.join(td, "model.xml")
+ ret = mujoco.mj_saveLastXML(filename.encode(), self._model)
+ return open(filename).read()
+
+ def get_joint_qpos_addr(self, name):
+ """
+ See https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L1178
+
+ Returns the qpos address for given joint.
+ Returns:
+ - address (int, tuple): returns int address if 1-dim joint, otherwise
+ returns the a (start, end) tuple for pos[start:end] access.
+ """
+ joint_id = self.joint_name2id(name)
+ joint_type = self.jnt_type[joint_id]
+ joint_addr = self.jnt_qposadr[joint_id]
+ if joint_type == mujoco.mjtJoint.mjJNT_FREE:
+ ndim = 7
+ elif joint_type == mujoco.mjtJoint.mjJNT_BALL:
+ ndim = 4
+ else:
+ assert joint_type in (mujoco.mjtJoint.mjJNT_HINGE, mujoco.mjtJoint.mjJNT_SLIDE)
+ ndim = 1
+
+ if ndim == 1:
+ return joint_addr
+ else:
+ return (joint_addr, joint_addr + ndim)
+
+ def get_joint_qvel_addr(self, name):
+ """
+ See https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L1202
+
+ Returns the qvel address for given joint.
+ Returns:
+ - address (int, tuple): returns int address if 1-dim joint, otherwise
+ returns the a (start, end) tuple for vel[start:end] access.
+ """
+ joint_id = self.joint_name2id(name)
+ joint_type = self.jnt_type[joint_id]
+ joint_addr = self.jnt_dofadr[joint_id]
+ if joint_type == mujoco.mjtJoint.mjJNT_FREE:
+ ndim = 6
+ elif joint_type == mujoco.mjtJoint.mjJNT_BALL:
+ ndim = 3
+ else:
+ assert joint_type in (mujoco.mjtJoint.mjJNT_HINGE, mujoco.mjtJoint.mjJNT_SLIDE)
+ ndim = 1
+
+ if ndim == 1:
+ return joint_addr
+ else:
+ return (joint_addr, joint_addr + ndim)
+
+
+class _MjDataMeta(type):
+ """
+ Metaclass which allows MjData below to delegate to mujoco.MjData.
+
+ Taken from dm_control.
+ """
+
+ def __new__(cls, name, bases, dct):
+ for attr in dir(mujoco.MjData):
+ if not attr.startswith("_"):
+ if attr not in dct:
+ # pylint: disable=protected-access
+ fget = lambda self, attr=attr: getattr(self._data, attr)
+ fset = lambda self, value, attr=attr: setattr(self._data, attr, value)
+ # pylint: enable=protected-access
+ dct[attr] = property(fget, fset)
+ return super().__new__(cls, name, bases, dct)
+
+
+class MjData(metaclass=_MjDataMeta):
+ """Wrapper class for a MuJoCo 'mjData' instance.
+ MjData contains all of the dynamic variables and intermediate results produced
+ by the simulation. These are expected to change on each simulation timestep.
+ The properties without docstrings are defined in mujoco source code from https://github.com/deepmind/mujoco/blob/062cb53a4a14b2a7a900453613a7ce498728f9d8/include/mujoco/mjdata.h#L126.
+ """
+
+ def __init__(self, model):
+ """Construct a new MjData instance.
+ Args:
+ model: An MjModel instance.
+ """
+ self._model = model
+ self._data = mujoco.MjData(model._model)
+
+ @property
+ def model(self):
+ """The parent MjModel for this MjData instance."""
+ return self._model
+
+ def __del__(self):
+ # free mujoco data
+ del self._data
+
+ """
+ Some methods supported by sim.data in mujoco-py.
+ Copied from https://github.com/openai/mujoco-py/blob/ab86d331c9a77ae412079c6e58b8771fe63747fc/mujoco_py/generated/wrappers.pxi#L2611
+ """
+
+ @property
+ def body_xpos(self):
+ """
+ Note: mujoco-py used to support sim.data.body_xpos but DM mujoco bindings requires sim.data.xpos,
+ so we explicitly expose this as a property
+ """
+ return self._data.xpos
+
+ @property
+ def body_xquat(self):
+ """
+ Note: mujoco-py used to support sim.data.body_xquat but DM mujoco bindings requires sim.data.xquat,
+ so we explicitly expose this as a property
+ """
+ return self._data.xquat
+
+ @property
+ def body_xmat(self):
+ """
+ Note: mujoco-py used to support sim.data.body_xmat but DM mujoco bindings requires sim.data.xmax,
+ so we explicitly expose this as a property
+ """
+ return self._data.xmat
+
+ def get_body_xpos(self, name):
+ """
+ Query cartesian position of a mujoco body using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ xpos (np.ndarray): The xpos value of the mujoco body
+ """
+ bid = self.model.body_name2id(name)
+ return self.xpos[bid]
+
+ def get_body_xquat(self, name):
+ """
+ Query the rotation of a mujoco body in quaternion (in wxyz convention) using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ xquat (np.ndarray): The xquat value of the mujoco body
+ """
+ bid = self.model.body_name2id(name)
+ return self.xquat[bid]
+
+ def get_body_xmat(self, name):
+ """
+ Query the rotation of a mujoco body in a rotation matrix using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ xmat (np.ndarray): The xmat value of the mujoco body
+ """
+ bid = self.model.body_name2id(name)
+ return self.xmat[bid].reshape((3, 3))
+
+ def get_body_jacp(self, name):
+ """
+ Query the position jacobian of a mujoco body using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ jacp (np.ndarray): The jacp value of the mujoco body
+ """
+ bid = self.model.body_name2id(name)
+ jacp = np.zeros((3, self.model.nv))
+ mujoco.mj_jacBody(self.model._model, self._data, jacp, None, bid)
+ return jacp
+
+ def get_body_jacr(self, name):
+ """
+ Query the rotation jacobian of a mujoco body using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ jacr (np.ndarray): The jacr value of the mujoco body
+ """
+ bid = self.model.body_name2id(name)
+ jacr = np.zeros((3, self.model.nv))
+ mujoco.mj_jacBody(self.model._model, self._data, None, jacr, bid)
+ return jacr
+
+ def get_body_xvelp(self, name):
+ """
+ Query the translational velocity of a mujoco body using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ xvelp (np.ndarray): The translational velocity of the mujoco body.
+ """
+ jacp = self.get_body_jacp(name)
+ xvelp = np.dot(jacp, self.qvel)
+ return xvelp
+
+ def get_body_xvelr(self, name):
+ """
+ Query the rotational velocity of a mujoco body using a name string.
+
+ Args:
+ name (str): The name of a mujoco body
+ Returns:
+ xvelr (np.ndarray): The rotational velocity of the mujoco body.
+ """
+ jacr = self.get_body_jacr(name)
+ xvelr = np.dot(jacr, self.qvel)
+ return xvelr
+
+ def get_geom_xpos(self, name):
+ """
+ Query the cartesian position of a mujoco geom using a name string.
+
+ Args:
+ name (str): The name of a mujoco geom
+ Returns:
+ geom_xpos (np.ndarray): The cartesian position of the mujoco body.
+ """
+ gid = self.model.geom_name2id(name)
+ return self.geom_xpos[gid]
+
+ def get_geom_xmat(self, name):
+ """
+ Query the rotation of a mujoco geom in a rotation matrix using a name string.
+
+ Args:
+ name (str): The name of a mujoco geom
+ Returns:
+ geom_xmat (np.ndarray): The 3x3 rotation matrix of the mujoco geom.
+ """
+ gid = self.model.geom_name2id(name)
+ return self.geom_xmat[gid].reshape((3, 3))
+
+ def get_geom_jacp(self, name):
+ """
+ Query the position jacobian of a mujoco geom using a name string.
+
+ Args:
+ name (str): The name of a mujoco geom
+ Returns:
+ jacp (np.ndarray): The jacp value of the mujoco geom
+ """
+ gid = self.model.geom_name2id(name)
+ jacp = np.zeros((3, self.model.nv))
+ mujoco.mj_jacGeom(self.model._model, self._data, jacp, None, gid)
+ return jacp
+
+ def get_geom_jacr(self, name):
+ """
+ Query the rotation jacobian of a mujoco geom using a name string.
+
+ Args:
+ name (str): The name of a mujoco geom
+ Returns:
+ jacr (np.ndarray): The jacr value of the mujoco geom
+ """
+ gid = self.model.geom_name2id(name)
+ jacv = np.zeros((3, self.model.nv))
+ mujoco.mj_jacGeom(self.model._model, self._data, None, jacv, gid)
+ return jacr
+
+ def get_geom_xvelp(self, name):
+ """
+ Query the translational velocity of a mujoco geom using a name string.
+
+ Args:
+ name (str): The name of a mujoco geom
+ Returns:
+ xvelp (np.ndarray): The translational velocity of the mujoco geom
+ """
+ jacp = self.get_geom_jacp(name)
+ xvelp = np.dot(jacp, self.qvel)
+ return xvelp
+
+ def get_geom_xvelr(self, name):
+ """
+ Query the rotational velocity of a mujoco geom using a name string.
+
+ Args:
+ name (str): The name of a mujoco geom
+ Returns:
+ xvelr (np.ndarray): The rotational velocity of the mujoco geom
+ """
+ jacr = self.get_geom_jacr(name)
+ xvelr = np.dot(jacr, self.qvel)
+ return xvelr
+
+ def get_site_xpos(self, name):
+ """
+ Query the cartesian position of a mujoco site using a name string.
+
+ Args:
+ name (str): The name of a mujoco site
+ Returns:
+ site_xpos (np.ndarray): The carteisan position of the mujoco site
+ """
+ sid = self.model.site_name2id(name)
+ return self.site_xpos[sid]
+
+ def get_site_xmat(self, name):
+ """
+ Query the rotation of a mujoco site in a rotation matrix using a name string.
+
+ Args:
+ name (str): The name of a mujoco site
+ Returns:
+ site_xmat (np.ndarray): The 3x3 rotation matrix of the mujoco site.
+ """
+ sid = self.model.site_name2id(name)
+ return self.site_xmat[sid].reshape((3, 3))
+
+ def get_site_jacp(self, name):
+ """
+ Query the position jacobian of a mujoco site using a name string.
+
+ Args:
+ name (str): The name of a mujoco site
+ Returns:
+ jacp (np.ndarray): The jacp value of the mujoco site
+ """
+ sid = self.model.site_name2id(name)
+ jacp = np.zeros((3, self.model.nv))
+ mujoco.mj_jacSite(self.model._model, self._data, jacp, None, sid)
+ return jacp
+
+ def get_site_jacr(self, name):
+ """
+ Query the rotation jacobian of a mujoco site using a name string.
+
+ Args:
+ name (str): The name of a mujoco site
+ Returns:
+ jacr (np.ndarray): The jacr value of the mujoco site
+ """
+ sid = self.model.site_name2id(name)
+ jacr = np.zeros((3, self.model.nv))
+ mujoco.mj_jacSite(self.model._model, self._data, None, jacr, sid)
+ return jacr
+
+ def get_site_xvelp(self, name):
+ """
+ Query the translational velocity of a mujoco site using a name string.
+
+ Args:
+ name (str): The name of a mujoco site
+ Returns:
+ xvelp (np.ndarray): The translational velocity of the mujoco site
+ """
+ jacp = self.get_site_jacp(name)
+ xvelp = np.dot(jacp, self.qvel)
+ return xvelp
+
+ def get_site_xvelr(self, name):
+ """
+ Query the rotational velocity of a mujoco site using a name string.
+
+ Args:
+ name (str): The name of a mujoco site
+ Returns:
+ xvelr (np.ndarray): The rotational velocity of the mujoco site
+ """
+ jacr = self.get_site_jacr(name)
+ xvelr = np.dot(jacr, self.qvel)
+ return xvelr
+
+ def get_camera_xpos(self, name):
+ """
+ Get the cartesian position of a camera using name
+
+ Args:
+ name (str): The name of a camera
+ Returns:
+ cam_xpos (np.ndarray): The cartesian position of a camera
+ """
+ cid = self.model.camera_name2id(name)
+ return self.cam_xpos[cid]
+
+ def get_camera_xmat(self, name):
+ """
+ Get the rotation of a camera in a rotation matrix using name
+
+ Args:
+ name (str): The name of a camera
+ Returns:
+ cam_xmat (np.ndarray): The 3x3 rotation matrix of a camera
+ """
+ cid = self.model.camera_name2id(name)
+ return self.cam_xmat[cid].reshape((3, 3))
+
+ def get_light_xpos(self, name):
+ """
+ Get cartesian position of a light source
+
+ Args:
+ name (str): The name of a lighting source
+ Returns:
+ light_xpos (np.ndarray): The cartesian position of the light source
+ """
+ lid = self.model.light_name2id(name)
+ return self.light_xpos[lid]
+
+ def get_light_xdir(self, name):
+ """
+ Get the direction of a light source using name
+
+ Args:
+ name (str): The name of a light
+ Returns:
+ light_xdir (np.ndarray): The direction vector of the lightsource
+ """
+ lid = self.model.light_name2id(name)
+ return self.light_xdir[lid]
+
+ def get_sensor(self, name):
+ """
+ Get the data of a sensor using name
+
+ Args:
+ name (str): The name of a sensor
+ Returns:
+ sensordata (np.ndarray): The sensor data vector
+ """
+ sid = self.model.sensor_name2id(name)
+ return self.sensordata[sid]
+
+ def get_mocap_pos(self, name):
+ """
+ Get the position of a mocap body using name.
+
+ Args:
+ name (str): The name of a joint
+ Returns:
+ mocap_pos (np.ndarray): The current position of a mocap body.
+ """
+ body_id = self.model.body_name2id(name)
+ mocap_id = self.model.body_mocapid[body_id]
+ return self.mocap_pos[mocap_id]
+
+ def set_mocap_pos(self, name, value):
+ """
+ Set the quaternion of a mocap body using name.
+
+ Args:
+ name (str): The name of a joint
+ value (float): The desired joint position of a mocap body.
+ """
+ body_id = self.model.body_name2id(name)
+ mocap_id = self.model.body_mocapid[body_id]
+ self.mocap_pos[mocap_id] = value
+
+ def get_mocap_quat(self, name):
+ """
+ Get the quaternion of a mocap body using name.
+
+ Args:
+ name (str): The name of a joint
+ Returns:
+ mocap_quat (np.ndarray): The current quaternion of a mocap body.
+ """
+ body_id = self.model.body_name2id(name)
+ mocap_id = self.model.body_mocapid[body_id]
+ return self.mocap_quat[mocap_id]
+
+ def set_mocap_quat(self, name, value):
+ """
+ Set the quaternion of a mocap body using name.
+
+ Args:
+ name (str): The name of a joint
+ value (float): The desired joint quaternion of a mocap body.
+ """
+ body_id = self.model.body_name2id(name)
+ mocap_id = self.model.body_mocapid[body_id]
+ self.mocap_quat[mocap_id] = value
+
+ def get_joint_qpos(self, name):
+ """
+ Get the position of a joint using name.
+
+ Args:
+ name (str): The name of a joint
+
+ Returns:
+ qpos (np.ndarray): The current position of a joint.
+ """
+ addr = self.model.get_joint_qpos_addr(name)
+ if isinstance(addr, (int, np.int32, np.int64)):
+ return self.qpos[addr]
+ else:
+ start_i, end_i = addr
+ return self.qpos[start_i:end_i]
+
+ def set_joint_qpos(self, name, value):
+ """
+ Set the velocities of a joint using name.
+
+ Args:
+ name (str): The name of a joint
+ value (float): The desired joint velocity of a joint.
+ """
+ addr = self.model.get_joint_qpos_addr(name)
+ if isinstance(addr, (int, np.int32, np.int64)):
+ self.qpos[addr] = value
+ else:
+ start_i, end_i = addr
+ value = np.array(value)
+ assert value.shape == (end_i - start_i,), "Value has incorrect shape %s: %s" % (name, value)
+ self.qpos[start_i:end_i] = value
+
+ def get_joint_qvel(self, name):
+ """
+ Get the velocity of a joint using name.
+
+ Args:
+ name (str): The name of a joint
+
+ Returns:
+ qvel (np.ndarray): The current velocity of a joint.
+ """
+ addr = self.model.get_joint_qvel_addr(name)
+ if isinstance(addr, (int, np.int32, np.int64)):
+ return self.qvel[addr]
+ else:
+ start_i, end_i = addr
+ return self.qvel[start_i:end_i]
+
+ def set_joint_qvel(self, name, value):
+ """
+ Set the velocities of a mjo using name.
+
+ Args:
+ name (str): The name of a joint
+ value (float): The desired joint velocity of a joint.
+ """
+ addr = self.model.get_joint_qvel_addr(name)
+ if isinstance(addr, (int, np.int32, np.int64)):
+ self.qvel[addr] = value
+ else:
+ start_i, end_i = addr
+ value = np.array(value)
+ assert value.shape == (end_i - start_i,), "Value has incorrect shape %s: %s" % (name, value)
+ self.qvel[start_i:end_i] = value
+
+
+class MjSim:
+ """
+ Meant to somewhat replicate functionality in mujoco-py's MjSim object
+ (see https://github.com/openai/mujoco-py/blob/master/mujoco_py/mjsim.pyx).
+ """
+
+ def __init__(self, model):
+ """
+ Args:
+ model: should be an MjModel instance created via a factory function
+ such as mujoco.MjModel.from_xml_string(xml)
+ """
+ self.model = MjModel(model)
+ self.data = MjData(self.model)
+
+ # offscreen render context object
+ self._render_context_offscreen = None
+
+ @classmethod
+ def from_xml_string(cls, xml):
+ model = mujoco.MjModel.from_xml_string(xml)
+ return cls(model)
+
+ @classmethod
+ def from_xml_file(cls, xml_file):
+ f = open(xml_file, "r")
+ xml = f.read()
+ f.close()
+ return cls.from_xml_string(xml)
+
+ def reset(self):
+ """Reset simulation."""
+ mujoco.mj_resetData(self.model._model, self.data._data)
+
+ def forward(self):
+ """Forward call to synchronize derived quantities."""
+ mujoco.mj_forward(self.model._model, self.data._data)
+
+ def step(self, with_udd=True):
+ """Step simulation."""
+ mujoco.mj_step(self.model._model, self.data._data)
+
+ def render(
+ self,
+ width=None,
+ height=None,
+ *,
+ camera_name=None,
+ depth=False,
+ mode="offscreen",
+ device_id=-1,
+ segmentation=False,
+ ):
+ """
+ Renders view from a camera and returns image as an `numpy.ndarray`.
+ Args:
+ - width (int): desired image width.
+ - height (int): desired image height.
+ - camera_name (str): name of camera in model. If None, the free
+ camera will be used.
+ - depth (bool): if True, also return depth buffer
+ - device (int): device to use for rendering (only for GPU-backed
+ rendering).
+ Returns:
+ - rgb (uint8 array): image buffer from camera
+ - depth (float array): depth buffer from camera (only returned
+ if depth=True)
+ """
+ if camera_name is None:
+ camera_id = None
+ else:
+ camera_id = self.model.camera_name2id(camera_name)
+
+ assert mode == "offscreen", "only offscreen supported for now"
+ assert self._render_context_offscreen is not None
+ with _MjSim_render_lock:
+ self._render_context_offscreen.render(
+ width=width, height=height, camera_id=camera_id, segmentation=segmentation
+ )
+ return self._render_context_offscreen.read_pixels(width, height, depth=depth, segmentation=segmentation)
+
+ def add_render_context(self, render_context):
+ assert render_context.offscreen
+ if self._render_context_offscreen is not None:
+ # free context
+ del self._render_context_offscreen
+ self._render_context_offscreen = render_context
+
+ def get_state(self):
+ """Return MjSimState instance for current state."""
+ return MjSimState(
+ time=self.data.time,
+ qpos=np.copy(self.data.qpos),
+ qvel=np.copy(self.data.qvel),
+ )
+
+ def set_state(self, value):
+ """
+ Set internal state from MjSimState instance. Should
+ call @forward afterwards to synchronize derived quantities.
+ """
+ self.data.time = value.time
+ self.data.qpos[:] = np.copy(value.qpos)
+ self.data.qvel[:] = np.copy(value.qvel)
+
+ def set_state_from_flattened(self, value):
+ """
+ Set internal mujoco state using flat mjstate array. Should
+ call @forward afterwards to synchronize derived quantities.
+
+ See https://github.com/openai/mujoco-py/blob/4830435a169c1f3e3b5f9b58a7c3d9c39bdf4acb/mujoco_py/mjsimstate.pyx#L54
+ """
+ state = MjSimState.from_flattened(value, self)
+
+ # do this instead of @set_state to avoid extra copy of qpos and qvel
+ self.data.time = state.time
+ self.data.qpos[:] = state.qpos
+ self.data.qvel[:] = state.qvel
+
+ def free(self):
+ # clean up here to prevent memory leaks
+ del self._render_context_offscreen
+ del self.data
+ del self.model
+ del self
+ gc.collect()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/buffers.py b/phantom/submodules/phantom-robosuite/robosuite/utils/buffers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a1bc20f86a79870d885d76a12739f0263c03fb
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/buffers.py
@@ -0,0 +1,173 @@
+"""
+Collection of Buffer objects with general functionality
+"""
+
+
+import numpy as np
+
+
+class Buffer(object):
+ """
+ Abstract class for different kinds of data buffers. Minimum API should have a "push" and "clear" method
+ """
+
+ def push(self, value):
+ """
+ Pushes a new @value to the buffer
+
+ Args:
+ value: Value to push to the buffer
+ """
+ raise NotImplementedError
+
+ def clear(self):
+ raise NotImplementedError
+
+
+class RingBuffer(Buffer):
+ """
+ Simple RingBuffer object to hold values to average (useful for, e.g.: filtering D component in PID control)
+
+ Note that the buffer object is a 2D numpy array, where each row corresponds to
+ individual entries into the buffer
+
+ Args:
+ dim (int): Size of entries being added. This is, e.g.: the size of a state vector that is to be stored
+ length (int): Size of the ring buffer
+ """
+
+ def __init__(self, dim, length):
+ # Store input args
+ self.dim = dim
+ self.length = length
+
+ # Variable so that initial average values are accurate
+ self._size = 0
+
+ # Save pointer to end of buffer
+ self.ptr = self.length - 1
+
+ # Construct ring buffer
+ self.buf = np.zeros((length, dim))
+
+ def push(self, value):
+ """
+ Pushes a new value into the buffer
+
+ Args:
+ value (int or float or array): Value(s) to push into the array (taken as a single new element)
+ """
+ # Increment pointer, then add value (also increment size if necessary)
+ self.ptr = (self.ptr + 1) % self.length
+ self.buf[self.ptr] = np.array(value)
+ if self._size < self.length:
+ self._size += 1
+
+ def clear(self):
+ """
+ Clears buffer and reset pointer
+ """
+ self.buf = np.zeros((self.length, self.dim))
+ self.ptr = self.length - 1
+ self._size = 0
+
+ @property
+ def current(self):
+ """
+ Gets the most recent value pushed to the buffer
+
+ Returns:
+ float or np.array: Most recent value in buffer
+ """
+ return self.buf[self.ptr]
+
+ @property
+ def average(self):
+ """
+ Gets the average of components in buffer
+
+ Returns:
+ float or np.array: Averaged value of all elements in buffer
+ """
+ return np.mean(self.buf[: self._size], axis=0)
+
+
+class DeltaBuffer(Buffer):
+ """
+ Simple 2-length buffer object to streamline grabbing delta values between "current" and "last" values
+
+ Constructs delta object.
+
+ Args:
+ dim (int): Size of numerical arrays being inputted
+ init_value (None or Iterable): Initial value to fill "last" value with initially.
+ If None (default), last array will be filled with zeros
+ """
+
+ def __init__(self, dim, init_value=None):
+ # Setup delta object
+ self.dim = dim
+ self.last = np.zeros(self.dim) if init_value is None else np.array(init_value)
+ self.current = np.zeros(self.dim)
+
+ def push(self, value):
+ """
+ Pushes a new value into the buffer; current becomes last and @value becomes current
+
+ Args:
+ value (int or float or array): Value(s) to push into the array (taken as a single new element)
+ """
+ self.last = self.current
+ self.current = np.array(value)
+
+ def clear(self):
+ """
+ Clears last and current value
+ """
+ self.last, self.current = np.zeros(self.dim), np.zeros(self.dim)
+
+ @property
+ def delta(self, abs_value=False):
+ """
+ Returns the delta between last value and current value. If abs_value is set to True, then returns
+ the absolute value between the values
+
+ Args:
+ abs_value (bool): Whether to return absolute value or not
+
+ Returns:
+ float or np.array: difference between current and last value
+ """
+ return self.current - self.last if not abs_value else np.abs(self.current - self.last)
+
+ @property
+ def average(self):
+ """
+ Returns the average between the current and last value
+
+ Returns:
+ float or np.array: Averaged value of all elements in buffer
+ """
+ return (self.current + self.last) / 2.0
+
+
+class DelayBuffer(RingBuffer):
+ """
+ Modified RingBuffer that returns delayed values when polled
+ """
+
+ def get_delayed_value(self, delay):
+ """
+ Returns value @delay increments behind most recent value.
+
+ Args:
+ delay (int): How many steps backwards from most recent value to grab value. Note that this should not be
+ greater than the buffer's length
+
+ Returns:
+ np.array: delayed value
+ """
+ # First make sure that the delay is valid
+ assert delay < self.length, "Requested delay must be less than buffer's length!"
+ # Grab delayed value
+ return self.buf[(self.ptr - delay) % self.length]
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/camera_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..733e65c57c6597ae7b065674ffe23ffad7d93d74
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/camera_utils.py
@@ -0,0 +1,628 @@
+"""
+This module includes:
+
+- Utility classes for modifying sim cameras
+
+- Utility functions for performing common camera operations such as retrieving
+camera matrices and transforming from world to camera frame or vice-versa.
+"""
+import json
+import xml.etree.ElementTree as ET
+
+import h5py
+import numpy as np
+
+import robosuite
+import robosuite.utils.transform_utils as T
+from robosuite.wrappers import DomainRandomizationWrapper, VisualizationWrapper
+
+
+def get_camera_intrinsic_matrix(sim, camera_name, camera_height, camera_width):
+ """
+ Obtains camera intrinsic matrix.
+
+ Args:
+ sim (MjSim): simulator instance
+ camera_name (str): name of camera
+ camera_height (int): height of camera images in pixels
+ camera_width (int): width of camera images in pixels
+ Return:
+ K (np.array): 3x3 camera matrix
+ """
+ cam_id = sim.model.camera_name2id(camera_name)
+ fovy = sim.model.cam_fovy[cam_id]
+ f = 0.5 * camera_height / np.tan(fovy * np.pi / 360)
+ K = np.array([[f, 0, camera_width / 2], [0, f, camera_height / 2], [0, 0, 1]])
+ return K
+
+
+def get_camera_extrinsic_matrix(sim, camera_name):
+ """
+ Returns a 4x4 homogenous matrix corresponding to the camera pose in the
+ world frame. MuJoCo has a weird convention for how it sets up the
+ camera body axis, so we also apply a correction so that the x and y
+ axis are along the camera view and the z axis points along the
+ viewpoint.
+ Normal camera convention: https://docs.opencv.org/2.4/modules/calib3d/doc/camera_calibration_and_3d_reconstruction.html
+
+ Args:
+ sim (MjSim): simulator instance
+ camera_name (str): name of camera
+ Return:
+ R (np.array): 4x4 camera extrinsic matrix
+ """
+ cam_id = sim.model.camera_name2id(camera_name)
+ camera_pos = sim.data.cam_xpos[cam_id]
+ camera_rot = sim.data.cam_xmat[cam_id].reshape(3, 3)
+ R = T.make_pose(camera_pos, camera_rot)
+
+ # IMPORTANT! This is a correction so that the camera axis is set up along the viewpoint correctly.
+ camera_axis_correction = np.array(
+ [[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
+ )
+ R = R @ camera_axis_correction
+ return R
+
+
+def get_camera_transform_matrix(sim, camera_name, camera_height, camera_width):
+ """
+ Camera transform matrix to project from world coordinates to pixel coordinates.
+
+ Args:
+ sim (MjSim): simulator instance
+ camera_name (str): name of camera
+ camera_height (int): height of camera images in pixels
+ camera_width (int): width of camera images in pixels
+ Return:
+ K (np.array): 4x4 camera matrix to project from world coordinates to pixel coordinates
+ """
+ R = get_camera_extrinsic_matrix(sim=sim, camera_name=camera_name)
+ K = get_camera_intrinsic_matrix(
+ sim=sim, camera_name=camera_name, camera_height=camera_height, camera_width=camera_width
+ )
+ K_exp = np.eye(4)
+ K_exp[:3, :3] = K
+
+ # Takes a point in world, transforms to camera frame, and then projects onto image plane.
+ return K_exp @ T.pose_inv(R)
+
+
+def get_camera_segmentation(sim, camera_name, camera_height, camera_width):
+ """
+ Obtains camera segmentation matrix.
+
+ Args:
+ sim (MjSim): simulator instance
+ camera_name (str): name of camera
+ camera_height (int): height of camera images in pixels
+ camera_width (int): width of camera images in pixels
+ Return:
+ im (np.array): 2-channel segmented image where the first contains the
+ geom types and the second contains the geom IDs
+ """
+ return sim.render(camera_name=camera_name, height=camera_height, width=camera_width, segmentation=True)[::-1]
+
+
+def get_real_depth_map(sim, depth_map):
+ """
+ By default, MuJoCo will return a depth map that is normalized in [0, 1]. This
+ helper function converts the map so that the entries correspond to actual distances.
+
+ (see https://github.com/deepmind/dm_control/blob/master/dm_control/mujoco/engine.py#L742)
+
+ Args:
+ sim (MjSim): simulator instance
+ depth_map (np.array): depth map with values normalized in [0, 1] (default depth map
+ returned by MuJoCo)
+ Return:
+ depth_map (np.array): depth map that corresponds to actual distances
+ """
+ # Make sure that depth values are normalized
+ assert np.all(depth_map >= 0.0) and np.all(depth_map <= 1.0)
+ extent = sim.model.stat.extent
+ far = sim.model.vis.map.zfar * extent
+ near = sim.model.vis.map.znear * extent
+ return near / (1.0 - depth_map * (1.0 - near / far))
+
+
+def project_points_from_world_to_camera(points, world_to_camera_transform, camera_height, camera_width):
+ """
+ Helper function to project a batch of points in the world frame
+ into camera pixels using the world to camera transformation.
+
+ Args:
+ points (np.array): 3D points in world frame to project onto camera pixel locations. Should
+ be shape [..., 3].
+ world_to_camera_transform (np.array): 4x4 Tensor to go from robot coordinates to pixel
+ coordinates.
+ camera_height (int): height of the camera image
+ camera_width (int): width of the camera image
+
+ Return:
+ pixels (np.array): projected pixel indices of shape [..., 2]
+ """
+ assert points.shape[-1] == 3 # last dimension must be 3D
+ assert len(world_to_camera_transform.shape) == 2
+ assert world_to_camera_transform.shape[0] == 4 and world_to_camera_transform.shape[1] == 4
+
+ # convert points to homogenous coordinates -> (px, py, pz, 1)
+ ones_pad = np.ones(points.shape[:-1] + (1,))
+ points = np.concatenate((points, ones_pad), axis=-1) # shape [..., 4]
+
+ # batch matrix multiplication of 4 x 4 matrix and 4 x 1 vectors to do robot frame to pixels transform
+ mat_reshape = [1] * len(points.shape[:-1]) + [4, 4]
+ cam_trans = world_to_camera_transform.reshape(mat_reshape) # shape [..., 4, 4]
+ pixels = np.matmul(cam_trans, points[..., None])[..., 0] # shape [..., 4]
+
+ # re-scaling from homogenous coordinates to recover pixel values
+ # (x, y, z) -> (x / z, y / z)
+ pixels = pixels / pixels[..., 2:3]
+ pixels = pixels[..., :2].round().astype(int) # shape [..., 2]
+
+ # swap first and second coordinates to get pixel indices that correspond to (height, width)
+ # and also clip pixels that are out of range of the camera image
+ pixels = np.concatenate(
+ (
+ pixels[..., 1:2].clip(0, camera_height - 1),
+ pixels[..., 0:1].clip(0, camera_width - 1),
+ ),
+ axis=-1,
+ )
+
+ return pixels
+
+
+def transform_from_pixels_to_world(pixels, depth_map, camera_to_world_transform):
+ """
+ Helper function to take a batch of pixel locations and the corresponding depth image
+ and transform these points from the camera frame to the world frame.
+
+ Args:
+ pixels (np.array): pixel coordinates of shape [..., 2]
+ depth_map (np.array): depth images of shape [..., H, W, 1]
+ camera_to_world_transform (np.array): 4x4 Tensor to go from pixel coordinates to world
+ coordinates.
+
+ Return:
+ points (np.array): 3D points in robot frame of shape [..., 3]
+ """
+
+ # make sure leading dimensions are consistent
+ pixels_leading_shape = pixels.shape[:-1]
+ depth_map_leading_shape = depth_map.shape[:-3]
+ assert depth_map_leading_shape == pixels_leading_shape
+
+ # sample from the depth map using the pixel locations with bilinear sampling
+ pixels = pixels.astype(float)
+ im_h, im_w = depth_map.shape[-2:]
+ depth_map_reshaped = depth_map.reshape(-1, im_h, im_w, 1)
+ z = bilinear_interpolate(im=depth_map_reshaped, x=pixels[..., 1:2], y=pixels[..., 0:1])
+ z = z.reshape(*depth_map_leading_shape, 1) # shape [..., 1]
+
+ # form 4D homogenous camera vector to transform - [x * z, y * z, z, 1]
+ # (note that we need to swap the first 2 dimensions of pixels to go from pixel indices
+ # to camera coordinates)
+ cam_pts = [pixels[..., 1:2] * z, pixels[..., 0:1] * z, z, np.ones_like(z)]
+ cam_pts = np.concatenate(cam_pts, axis=-1) # shape [..., 4]
+
+ # batch matrix multiplication of 4 x 4 matrix and 4 x 1 vectors to do camera to robot frame transform
+ mat_reshape = [1] * len(cam_pts.shape[:-1]) + [4, 4]
+ cam_trans = camera_to_world_transform.reshape(mat_reshape) # shape [..., 4, 4]
+ points = np.matmul(cam_trans, cam_pts[..., None])[..., 0] # shape [..., 4]
+ return points[..., :3]
+
+
+def bilinear_interpolate(im, x, y):
+ """
+ Bilinear sampling for pixel coordinates x and y from source image im.
+ Taken from https://stackoverflow.com/questions/12729228/simple-efficient-bilinear-interpolation-of-images-in-numpy-and-python
+ """
+ x = np.asarray(x)
+ y = np.asarray(y)
+
+ x0 = np.floor(x).astype(int)
+ x1 = x0 + 1
+ y0 = np.floor(y).astype(int)
+ y1 = y0 + 1
+
+ x0 = np.clip(x0, 0, im.shape[1] - 1)
+ x1 = np.clip(x1, 0, im.shape[1] - 1)
+ y0 = np.clip(y0, 0, im.shape[0] - 1)
+ y1 = np.clip(y1, 0, im.shape[0] - 1)
+
+ Ia = im[y0, x0]
+ Ib = im[y1, x0]
+ Ic = im[y0, x1]
+ Id = im[y1, x1]
+
+ wa = (x1 - x) * (y1 - y)
+ wb = (x1 - x) * (y - y0)
+ wc = (x - x0) * (y1 - y)
+ wd = (x - x0) * (y - y0)
+
+ return wa * Ia + wb * Ib + wc * Ic + wd * Id
+
+
+class CameraMover:
+ """
+ A class for manipulating a camera.
+
+ WARNING: This class will initially RE-INITIALIZE the environment.
+
+ Args:
+ env (MujocoEnv): Mujoco environment to modify camera
+ camera (str): Which camera to mobilize during playback, e.g.: frontview, agentview, etc.
+ init_camera_pos (None or 3-array): If specified, should be the (x,y,z) global cartesian pos to
+ initialize camera to
+ init_camera_quat (None or 4-array): If specified, should be the (x,y,z,w) global quaternion orientation to
+ initialize camera to
+ """
+
+ def __init__(
+ self,
+ env,
+ camera="frontview",
+ init_camera_pos=None,
+ init_camera_quat=None,
+ ):
+ # Store relevant values and initialize other values
+ self.env = env
+ self.camera = camera
+ self.mover_body_name = f"{self.camera}_cameramover"
+
+ # Get state
+ state = self.env.sim.get_state().flatten()
+
+ # Grab environment xml
+ xml = env.sim.model.get_xml()
+
+ # Modify xml to add mocap to move camera around
+ xml = self.modify_xml_for_camera_movement(xml=xml, camera_name=self.camera)
+
+ # Reset the environment and restore the state
+ self.env.reset_from_xml_string(xml)
+ self.env.sim.reset()
+ self.env.sim.set_state_from_flattened(state)
+ self.env.sim.forward()
+
+ # Set initial camera pose
+ self.set_camera_pose(pos=init_camera_pos, quat=init_camera_quat)
+
+ def set_camera_pose(self, pos=None, quat=None):
+ """
+ Sets the camera pose, which optionally includes position and / or quaternion
+
+ Args:
+ pos (None or 3-array): If specified, should be the (x,y,z) global cartesian pos to set camera to
+ quat (None or 4-array): If specified, should be the (x,y,z,w) global quaternion orientation to set camera to
+ """
+ if pos is not None:
+ self.env.sim.data.set_mocap_pos(self.mover_body_name, pos)
+ if quat is not None:
+ self.env.sim.data.set_mocap_quat(self.mover_body_name, T.convert_quat(quat, to="wxyz"))
+
+ # Make sure changes propagate in sim
+ self.env.sim.forward()
+
+ def get_camera_pose(self):
+ """
+ Grab the current camera pose, which optionally includes position and / or quaternion
+
+ Returns:
+ 2-tuple:
+ - 3-array: (x,y,z) camera global cartesian pos
+ - 4-array: (x,y,z,w) camera global quaternion orientation
+ """
+ # Grab values from sim
+ pos = self.env.sim.data.get_mocap_pos(self.mover_body_name)
+ quat = T.convert_quat(self.env.sim.data.get_mocap_quat(self.mover_body_name), to="xyzw")
+
+ return pos, quat
+
+ def modify_xml_for_camera_movement(self, xml, camera_name):
+ """
+ Cameras in mujoco are 'fixed', so they can't be moved by default.
+ Although it's possible to hack position movement, rotation movement
+ does not work. An alternative is to attach a camera to a mocap body,
+ and move the mocap body.
+
+ This function modifies the camera with name @camera_name in the xml
+ by attaching it to a mocap body that can move around freely. In this
+ way, we can move the camera by moving the mocap body.
+
+ See http://www.mujoco.org/forum/index.php?threads/move-camera.2201/ for
+ further details.
+
+ Args:
+ xml (str): Mujoco sim XML file as a string
+ camera_name (str): Name of camera to tune
+ """
+ tree = ET.fromstring(xml)
+ wb = tree.find("worldbody")
+
+ # find the correct camera
+ camera_elem = None
+ cameras = wb.findall("camera")
+ for camera in cameras:
+ if camera.get("name") == camera_name:
+ camera_elem = camera
+ break
+ assert camera_elem is not None
+
+ # add mocap body
+ mocap = ET.SubElement(wb, "body")
+ mocap.set("name", self.mover_body_name)
+ mocap.set("mocap", "true")
+ mocap.set("pos", camera.get("pos"))
+ mocap.set("quat", camera.get("quat"))
+ new_camera = ET.SubElement(mocap, "camera")
+ new_camera.set("mode", "fixed")
+ new_camera.set("name", camera.get("name"))
+ new_camera.set("pos", "0 0 0")
+
+ # remove old camera element
+ wb.remove(camera_elem)
+
+ return ET.tostring(tree, encoding="utf8").decode("utf8")
+
+ def rotate_camera(self, point, axis, angle):
+ """
+ Rotate the camera view about a direction (in the camera frame).
+
+ Args:
+ point (None or 3-array): (x,y,z) cartesian coordinates about which to rotate camera in camera frame. If None,
+ assumes the point is the current location of the camera
+ axis (3-array): (ax,ay,az) axis about which to rotate camera in camera frame
+ angle (float): how much to rotate about that direction
+
+ Returns:
+ 2-tuple:
+ pos: (x,y,z) updated camera position
+ quat: (x,y,z,w) updated camera quaternion orientation
+ """
+ # current camera rotation + pos
+ camera_pos = np.array(self.env.sim.data.get_mocap_pos(self.mover_body_name))
+ camera_rot = T.quat2mat(T.convert_quat(self.env.sim.data.get_mocap_quat(self.mover_body_name), to="xyzw"))
+
+ # rotate by angle and direction to get new camera rotation
+ rad = np.pi * angle / 180.0
+ R = T.rotation_matrix(rad, axis, point=point)
+ camera_pose = np.zeros((4, 4))
+ camera_pose[:3, :3] = camera_rot
+ camera_pose[:3, 3] = camera_pos
+ camera_pose = camera_pose @ R
+
+ # Update camera pose
+ pos, quat = camera_pose[:3, 3], T.mat2quat(camera_pose[:3, :3])
+ self.set_camera_pose(pos=pos, quat=quat)
+
+ return pos, quat
+
+ def move_camera(self, direction, scale):
+ """
+ Move the camera view along a direction (in the camera frame).
+
+ Args:
+ direction (3-array): direction vector for where to move camera in camera frame
+ scale (float): how much to move along that direction
+ """
+ # current camera rotation + pos
+ camera_pos = np.array(self.env.sim.data.get_mocap_pos(self.mover_body_name))
+ camera_quat = self.env.sim.data.get_mocap_quat(self.mover_body_name)
+ camera_rot = T.quat2mat(T.convert_quat(camera_quat, to="xyzw"))
+
+ # move along camera frame axis and set new position
+ camera_pos += scale * camera_rot.dot(direction)
+ self.set_camera_pose(pos=camera_pos)
+
+ return camera_pos, camera_quat
+
+
+class DemoPlaybackCameraMover(CameraMover):
+ """
+ A class for playing back demonstrations and recording the resulting frames with the flexibility of a mobile camera
+ that can be set manually or panned automatically frame-by-frame
+
+ Note: domain randomization is also supported for playback!
+
+ Args:
+ demo (str): absolute fpath to .hdf5 demo
+ env_config (None or dict): (optional) values to override inferred environment information from demonstration.
+ (e.g.: camera h / w, depths, segmentations, etc...)
+ Any value not specified will be inferred from the extracted demonstration metadata
+ Note that there are some specific arguments that MUST be set a certain way, if any of these values
+ are specified with @env_config, an error will be raised
+ replay_from_actions (bool): If True, will replay demonstration's actions. Otherwise, replays will be hardcoded
+ from the demonstration states
+ visualize_sites (bool): If True, will visualize sites during playback. Note that this CANNOT be paired
+ simultaneously with camera segmentations
+ camera (str): Which camera to mobilize during playback, e.g.: frontview, agentview, etc.
+ init_camera_pos (None or 3-array): If specified, should be the (x,y,z) global cartesian pos to
+ initialize camera to
+ init_camera_quat (None or 4-array): If specified, should be the (x,y,z,w) global quaternion orientation to
+ initialize camera to
+ use_dr (bool): If True, will use domain randomization during playback
+ dr_args (None or dict): If specified, will set the domain randomization wrapper arguments if using dr
+ """
+
+ def __init__(
+ self,
+ demo,
+ env_config=None,
+ replay_from_actions=False,
+ visualize_sites=False,
+ camera="frontview",
+ init_camera_pos=None,
+ init_camera_quat=None,
+ use_dr=False,
+ dr_args=None,
+ ):
+ # Store relevant values and initialize other values
+ self.camera_id = None
+ self.replay_from_actions = replay_from_actions
+ self.states = None
+ self.actions = None
+ self.step = None
+ self.n_steps = None
+ self.current_ep = None
+ self.started = False
+
+ # Load the demo
+ self.f = h5py.File(demo, "r")
+
+ # Extract relevant info
+ env_info = json.loads(self.f["data"].attrs["env_info"])
+
+ # Construct default env arguments
+ default_args = {
+ "has_renderer": False,
+ "has_offscreen_renderer": True,
+ "ignore_done": True,
+ "use_camera_obs": True,
+ "reward_shaping": True,
+ "hard_reset": False,
+ "camera_names": camera,
+ }
+
+ # If custom env_config is specified, make sure that there's no overlap with default args and merge with config
+ if env_config is not None:
+ for k in env_config.keys():
+ assert k not in default_args, f"Key {k} cannot be specified in env_config!"
+ env_info.update(env_config)
+
+ # Merge in default args
+ env_info.update(default_args)
+
+ # Create env
+ env = robosuite.make(**env_info)
+
+ # Optionally wrap with visualization wrapper
+ if visualize_sites:
+ env = VisualizationWrapper(env=self.env)
+
+ # Optionally use domain randomization if specified
+ self.use_dr = use_dr
+ if self.use_dr:
+ default_dr_args = {
+ "seed": 1,
+ "randomize_camera": False,
+ "randomize_every_n_steps": 10,
+ }
+ default_dr_args.update(dr_args)
+ env = DomainRandomizationWrapper(
+ env=self.env,
+ **default_dr_args,
+ )
+
+ # list of all demonstrations episodes
+ self.demos = list(self.f["data"].keys())
+
+ # Run super init
+ super().__init__(
+ env=env,
+ camera=camera,
+ init_camera_pos=init_camera_pos,
+ init_camera_quat=init_camera_quat,
+ )
+
+ # Load episode 0 by default
+ self.load_episode_xml(demo_num=0)
+
+ def load_episode_xml(self, demo_num):
+ """
+ Loads demo episode with specified @demo_num into the simulator.
+
+ Args:
+ demo_num (int): Demonstration number to load
+ """
+ # Grab raw xml file
+ ep = self.demos[demo_num]
+ model_xml = self.f[f"data/{ep}"].attrs["model_file"]
+
+ # Reset environment
+ self.env.reset()
+ xml = self.env.edit_model_xml(model_xml)
+ xml = self.modify_xml_for_camera_movement(xml, camera_name=self.camera)
+ self.env.reset_from_xml_string(xml)
+ self.env.sim.reset()
+
+ # Update camera info
+ self.camera_id = self.env.sim.model.camera_name2id(self.camera)
+
+ # Load states and actions
+ self.states = self.f[f"data/{ep}/states"].value
+ self.actions = np.array(self.f[f"data/{ep}/actions"].value)
+
+ # Set initial state
+ self.env.sim.set_state_from_flattened(self.states[0])
+
+ # Reset step count and set current episode number
+ self.step = 0
+ self.n_steps = len(self.actions)
+ self.current_ep = demo_num
+
+ # Notify user of loaded episode
+ print(f"Loaded episode {demo_num}.")
+
+ def grab_next_frame(self):
+ """
+ Grabs the next frame in the demo sequence by stepping the simulation and returning the resulting value(s)
+
+ Returns:
+ dict: Keyword-mapped np.arrays from the demonstration sequence, corresponding to all image modalities used
+ in the playback environment (e.g.: "image", "depth", "segmentation_instance")
+ """
+ # Make sure the episode isn't completed yet, if so, we load the next episode
+ if self.step == self.n_steps:
+ self.load_episode_xml(demo_num=self.current_ep + 1)
+
+ # Step the environment and grab obs
+ if self.replay_from_actions:
+ obs, _, _, _ = self.env.step(self.actions[self.step])
+ else: # replay from states
+ self.env.sim.set_state_from_flattened(self.states[self.step + 1])
+ if self.use_dr:
+ self.env.step_randomization()
+ self.env.sim.forward()
+ obs = self.env._get_observation()
+
+ # Increment the step counter
+ self.step += 1
+
+ # Return all relevant frames
+ return {k.split(f"{self.camera}_")[-1]: obs[k] for k in obs if self.camera in k}
+
+ def grab_episode_frames(self, demo_num, pan_point=(0, 0, 0.8), pan_axis=(0, 0, 1), pan_rate=0.01):
+ """
+ Playback entire episode @demo_num, while optionally rotating the camera about point @pan_point and
+ axis @pan_axis if @pan_rate > 0
+
+ Args:
+ demo_num (int): Demonstration episode number to load for playback
+ pan_point (3-array): (x,y,z) cartesian coordinates about which to rotate camera in camera frame
+ pan_direction (3-array): (ax,ay,az) axis about which to rotate camera in camera frame
+ pan_rate (float): how quickly to pan camera if not 0
+
+ Returns:
+ dict: Keyword-mapped stacked np.arrays from the demonstration sequence, corresponding to all image
+ modalities used in the playback environment (e.g.: "image", "depth", "segmentation_instance")
+
+ """
+ # First, load env
+ self.load_episode_xml(demo_num=demo_num)
+
+ # Initialize dict to return
+ obs = self.env._get_observation()
+ frames_dict = {k.split(f"{self.camera}_")[-1]: [] for k in obs if self.camera in k}
+
+ # Continue to loop playback steps while there are still frames left in the episode
+ while self.step < self.n_steps:
+ # Take playback step and add frames
+ for k, frame in self.grab_next_frame().items():
+ frames_dict[k].append(frame)
+
+ # Update camera pose
+ self.rotate_camera(point=pan_point, axis=pan_axis, angle=pan_rate)
+
+ # Stack all frames and return
+ return {k: np.stack(frames) for k, frames in frames_dict.items()}
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/control_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/control_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..18cc0a6b30cd8011f5267cc39542c86a01867e44
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/control_utils.py
@@ -0,0 +1,236 @@
+import numpy as np
+
+import robosuite.utils.transform_utils as trans
+from robosuite.utils.numba import jit_decorator
+
+
+@jit_decorator
+def nullspace_torques(mass_matrix, nullspace_matrix, initial_joint, joint_pos, joint_vel, joint_kp=10):
+ """
+ For a robot with redundant DOF(s), a nullspace exists which is orthogonal to the remainder of the controllable
+ subspace of the robot's joints. Therefore, an additional secondary objective that does not impact the original
+ controller objective may attempt to be maintained using these nullspace torques.
+
+ This utility function specifically calculates nullspace torques that attempt to maintain a given robot joint
+ positions @initial_joint with zero velocity using proportinal gain @joint_kp
+
+ :Note: @mass_matrix, @nullspace_matrix, @joint_pos, and @joint_vel should reflect the robot's state at the current
+ timestep
+
+ Args:
+ mass_matrix (np.array): 2d array representing the mass matrix of the robot
+ nullspace_matrix (np.array): 2d array representing the nullspace matrix of the robot
+ initial_joint (np.array): Joint configuration to be used for calculating nullspace torques
+ joint_pos (np.array): Current joint positions
+ joint_vel (np.array): Current joint velocities
+ joint_kp (float): Proportional control gain when calculating nullspace torques
+
+ Returns:
+ np.array: nullspace torques
+ """
+
+ # kv calculated below corresponds to critical damping
+ joint_kv = np.sqrt(joint_kp) * 2
+
+ # calculate desired torques based on gains and error
+ pose_torques = np.dot(mass_matrix, (joint_kp * (initial_joint - joint_pos) - joint_kv * joint_vel))
+
+ # map desired torques to null subspace within joint torque actuator space
+ nullspace_torques = np.dot(nullspace_matrix.transpose(), pose_torques)
+ return nullspace_torques
+
+
+@jit_decorator
+def opspace_matrices(mass_matrix, J_full, J_pos, J_ori):
+ """
+ Calculates the relevant matrices used in the operational space control algorithm
+
+ Args:
+ mass_matrix (np.array): 2d array representing the mass matrix of the robot
+ J_full (np.array): 2d array representing the full Jacobian matrix of the robot
+ J_pos (np.array): 2d array representing the position components of the Jacobian matrix of the robot
+ J_ori (np.array): 2d array representing the orientation components of the Jacobian matrix of the robot
+
+ Returns:
+ 4-tuple:
+
+ - (np.array): full lambda matrix (as 2d array)
+ - (np.array): position components of lambda matrix (as 2d array)
+ - (np.array): orientation components of lambda matrix (as 2d array)
+ - (np.array): nullspace matrix (as 2d array)
+ """
+ mass_matrix_inv = np.linalg.inv(mass_matrix)
+
+ # J M^-1 J^T
+ lambda_full_inv = np.dot(np.dot(J_full, mass_matrix_inv), J_full.transpose())
+
+ # Jx M^-1 Jx^T
+ lambda_pos_inv = np.dot(np.dot(J_pos, mass_matrix_inv), J_pos.transpose())
+
+ # Jr M^-1 Jr^T
+ lambda_ori_inv = np.dot(np.dot(J_ori, mass_matrix_inv), J_ori.transpose())
+
+ # take the inverses, but zero out small singular values for stability
+ lambda_full = np.linalg.pinv(lambda_full_inv)
+ lambda_pos = np.linalg.pinv(lambda_pos_inv)
+ lambda_ori = np.linalg.pinv(lambda_ori_inv)
+
+ # nullspace
+ Jbar = np.dot(mass_matrix_inv, J_full.transpose()).dot(lambda_full)
+ nullspace_matrix = np.eye(J_full.shape[-1], J_full.shape[-1]) - np.dot(Jbar, J_full)
+
+ return lambda_full, lambda_pos, lambda_ori, nullspace_matrix
+
+
+@jit_decorator
+def orientation_error(desired, current):
+ """
+ This function calculates a 3-dimensional orientation error vector for use in the
+ impedance controller. It does this by computing the delta rotation between the
+ inputs and converting that rotation to exponential coordinates (axis-angle
+ representation, where the 3d vector is axis * angle).
+ See https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation for more information.
+ Optimized function to determine orientation error from matrices
+
+ Args:
+ desired (np.array): 2d array representing target orientation matrix
+ current (np.array): 2d array representing current orientation matrix
+
+ Returns:
+ np.array: 2d array representing orientation error as a matrix
+ """
+ rc1 = current[0:3, 0]
+ rc2 = current[0:3, 1]
+ rc3 = current[0:3, 2]
+ rd1 = desired[0:3, 0]
+ rd2 = desired[0:3, 1]
+ rd3 = desired[0:3, 2]
+
+ error = 0.5 * (np.cross(rc1, rd1) + np.cross(rc2, rd2) + np.cross(rc3, rd3))
+
+ return error
+
+
+def set_goal_position(delta, current_position, position_limit=None, set_pos=None):
+ """
+ Calculates and returns the desired goal position, clipping the result accordingly to @position_limits.
+ @delta and @current_position must be specified if a relative goal is requested, else @set_pos must be
+ specified to define a global goal position
+
+ Args:
+ delta (np.array): Desired relative change in position
+ current_position (np.array): Current position
+ position_limit (None or np.array): 2d array defining the (min, max) limits of permissible position goal commands
+ set_pos (None or np.array): If set, will ignore @delta and set the goal position to this value
+
+ Returns:
+ np.array: calculated goal position in absolute coordinates
+
+ Raises:
+ ValueError: [Invalid position_limit shape]
+ """
+ n = len(current_position)
+ if set_pos is not None:
+ goal_position = set_pos
+ else:
+ goal_position = current_position + delta
+
+ if position_limit is not None:
+ if position_limit.shape != (2, n):
+ raise ValueError(
+ "Position limit should be shaped (2,{}) " "but is instead: {}".format(n, position_limit.shape)
+ )
+
+ # Clip goal position
+ goal_position = np.clip(goal_position, position_limit[0], position_limit[1])
+
+ return goal_position
+
+
+def set_goal_orientation(delta, current_orientation, orientation_limit=None, set_ori=None):
+ """
+ Calculates and returns the desired goal orientation, clipping the result accordingly to @orientation_limits.
+ @delta and @current_orientation must be specified if a relative goal is requested, else @set_ori must be
+ an orientation matrix specified to define a global orientation
+
+ Args:
+ delta (np.array): Desired relative change in orientation, in axis-angle form [ax, ay, az]
+ current_orientation (np.array): Current orientation, in rotation matrix form
+ orientation_limit (None or np.array): 2d array defining the (min, max) limits of permissible orientation goal commands
+ set_ori (None or np.array): If set, will ignore @delta and set the goal orientation to this value
+
+ Returns:
+ np.array: calculated goal orientation in absolute coordinates
+
+ Raises:
+ ValueError: [Invalid orientation_limit shape]
+ """
+ # directly set orientation
+ if set_ori is not None:
+ goal_orientation = set_ori
+
+ # otherwise use delta to set goal orientation
+ else:
+ # convert axis-angle value to rotation matrix
+ quat_error = trans.axisangle2quat(delta)
+ rotation_mat_error = trans.quat2mat(quat_error)
+ goal_orientation = np.dot(rotation_mat_error, current_orientation)
+
+ # check for orientation limits
+ if np.array(orientation_limit).any():
+ if orientation_limit.shape != (2, 3):
+ raise ValueError(
+ "Orientation limit should be shaped (2,3) " "but is instead: {}".format(orientation_limit.shape)
+ )
+
+ # Convert to euler angles for clipping
+ euler = trans.mat2euler(goal_orientation)
+
+ # Clip euler angles according to specified limits
+ limited = False
+ for idx in range(3):
+ if orientation_limit[0][idx] < orientation_limit[1][idx]: # Normal angle sector meaning
+ if orientation_limit[0][idx] < euler[idx] < orientation_limit[1][idx]:
+ continue
+ else:
+ limited = True
+ dist_to_lower = euler[idx] - orientation_limit[0][idx]
+ if dist_to_lower > np.pi:
+ dist_to_lower -= 2 * np.pi
+ elif dist_to_lower < -np.pi:
+ dist_to_lower += 2 * np.pi
+
+ dist_to_higher = euler[idx] - orientation_limit[1][idx]
+ if dist_to_lower > np.pi:
+ dist_to_higher -= 2 * np.pi
+ elif dist_to_lower < -np.pi:
+ dist_to_higher += 2 * np.pi
+
+ if dist_to_lower < dist_to_higher:
+ euler[idx] = orientation_limit[0][idx]
+ else:
+ euler[idx] = orientation_limit[1][idx]
+ else: # Inverted angle sector meaning
+ if orientation_limit[0][idx] < euler[idx] or euler[idx] < orientation_limit[1][idx]:
+ continue
+ else:
+ limited = True
+ dist_to_lower = euler[idx] - orientation_limit[0][idx]
+ if dist_to_lower > np.pi:
+ dist_to_lower -= 2 * np.pi
+ elif dist_to_lower < -np.pi:
+ dist_to_lower += 2 * np.pi
+
+ dist_to_higher = euler[idx] - orientation_limit[1][idx]
+ if dist_to_lower > np.pi:
+ dist_to_higher -= 2 * np.pi
+ elif dist_to_lower < -np.pi:
+ dist_to_higher += 2 * np.pi
+
+ if dist_to_lower < dist_to_higher:
+ euler[idx] = orientation_limit[0][idx]
+ else:
+ euler[idx] = orientation_limit[1][idx]
+ if limited:
+ goal_orientation = trans.euler2mat(np.array([euler[0], euler[1], euler[2]]))
+ return goal_orientation
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/errors.py b/phantom/submodules/phantom-robosuite/robosuite/utils/errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe0b5cd7119b9a987975ae6028a9e76326d1f8c
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/errors.py
@@ -0,0 +1,22 @@
+class robosuiteError(Exception):
+ """Base class for exceptions in robosuite."""
+
+ pass
+
+
+class XMLError(robosuiteError):
+ """Exception raised for errors related to xml."""
+
+ pass
+
+
+class SimulationError(robosuiteError):
+ """Exception raised for errors during runtime."""
+
+ pass
+
+
+class RandomizationError(robosuiteError):
+ """Exception raised for really really bad RNG."""
+
+ pass
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/input_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/input_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..26dee46cce4db7a9d7060064a96a07c8a23fa7ee
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/input_utils.py
@@ -0,0 +1,255 @@
+"""
+Utility functions for grabbing user inputs
+"""
+
+import numpy as np
+
+import robosuite as suite
+import robosuite.utils.transform_utils as T
+from robosuite.devices import *
+from robosuite.models.robots import *
+from robosuite.robots import *
+
+
+def choose_environment():
+ """
+ Prints out environment options, and returns the selected env_name choice
+
+ Returns:
+ str: Chosen environment name
+ """
+ # get the list of all environments
+ envs = sorted(suite.ALL_ENVIRONMENTS)
+
+ # Select environment to run
+ print("Here is a list of environments in the suite:\n")
+
+ for k, env in enumerate(envs):
+ print("[{}] {}".format(k, env))
+ print()
+ try:
+ s = input("Choose an environment to run " + "(enter a number from 0 to {}): ".format(len(envs) - 1))
+ # parse input into a number within range
+ k = min(max(int(s), 0), len(envs))
+ except:
+ k = 0
+ print("Input is not valid. Use {} by default.\n".format(envs[k]))
+
+ # Return the chosen environment name
+ return envs[k]
+
+
+def choose_controller():
+ """
+ Prints out controller options, and returns the requested controller name
+
+ Returns:
+ str: Chosen controller name
+ """
+ # get the list of all controllers
+ controllers_info = suite.controllers.CONTROLLER_INFO
+ controllers = list(suite.ALL_CONTROLLERS)
+
+ # Select controller to use
+ print("Here is a list of controllers in the suite:\n")
+
+ for k, controller in enumerate(controllers):
+ print("[{}] {} - {}".format(k, controller, controllers_info[controller]))
+ print()
+ try:
+ s = input("Choose a controller for the robot " + "(enter a number from 0 to {}): ".format(len(controllers) - 1))
+ # parse input into a number within range
+ k = min(max(int(s), 0), len(controllers) - 1)
+ except:
+ k = 0
+ print("Input is not valid. Use {} by default.".format(controllers)[k])
+
+ # Return chosen controller
+ return controllers[k]
+
+
+def choose_multi_arm_config():
+ """
+ Prints out multi-arm environment configuration options, and returns the requested config name
+
+ Returns:
+ str: Requested multi-arm configuration name
+ """
+ # Get the list of all multi arm configs
+ env_configs = {
+ "Single Arms Opposed": "single-arm-opposed",
+ "Single Arms Parallel": "single-arm-parallel",
+ "Bimanual": "bimanual",
+ }
+
+ # Select environment configuration
+ print("A multi-arm environment was chosen. Here is a list of multi-arm environment configurations:\n")
+
+ for k, env_config in enumerate(list(env_configs)):
+ print("[{}] {}".format(k, env_config))
+ print()
+ try:
+ s = input(
+ "Choose a configuration for this environment "
+ + "(enter a number from 0 to {}): ".format(len(env_configs) - 1)
+ )
+ # parse input into a number within range
+ k = min(max(int(s), 0), len(env_configs))
+ except:
+ k = 0
+ print("Input is not valid. Use {} by default.".format(list(env_configs)[k]))
+
+ # Return requested configuration
+ return list(env_configs.values())[k]
+
+
+def choose_robots(exclude_bimanual=False):
+ """
+ Prints out robot options, and returns the requested robot. Restricts options to single-armed robots if
+ @exclude_bimanual is set to True (False by default)
+
+ Args:
+ exclude_bimanual (bool): If set, excludes bimanual robots from the robot options
+
+ Returns:
+ str: Requested robot name
+ """
+ # Get the list of robots
+ robots = {
+ "Sawyer",
+ "Panda",
+ "Jaco",
+ "Kinova3",
+ "IIWA",
+ "UR5e",
+ }
+
+ # Add Baxter if bimanual robots are not excluded
+ if not exclude_bimanual:
+ robots.add("Baxter")
+
+ # Make sure set is deterministically sorted
+ robots = sorted(robots)
+
+ # Select robot
+ print("Here is a list of available robots:\n")
+
+ for k, robot in enumerate(robots):
+ print("[{}] {}".format(k, robot))
+ print()
+ try:
+ s = input("Choose a robot " + "(enter a number from 0 to {}): ".format(len(robots) - 1))
+ # parse input into a number within range
+ k = min(max(int(s), 0), len(robots))
+ except:
+ k = 0
+ print("Input is not valid. Use {} by default.".format(list(robots)[k]))
+
+ # Return requested robot
+ return list(robots)[k]
+
+
+def input2action(device, robot, active_arm="right", env_configuration=None):
+ """
+ Converts an input from an active device into a valid action sequence that can be fed into an env.step() call
+
+ If a reset is triggered from the device, immediately returns None. Else, returns the appropriate action
+
+ Args:
+ device (Device): A device from which user inputs can be converted into actions. Can be either a Spacemouse or
+ Keyboard device class
+
+ robot (Robot): Which robot we're controlling
+
+ active_arm (str): Only applicable for multi-armed setups (e.g.: multi-arm environments or bimanual robots).
+ Allows inputs to be converted correctly if the control type (e.g.: IK) is dependent on arm choice.
+ Choices are {right, left}
+
+ env_configuration (str or None): Only applicable for multi-armed environments. Allows inputs to be converted
+ correctly if the control type (e.g.: IK) is dependent on the environment setup. Options are:
+ {bimanual, single-arm-parallel, single-arm-opposed}
+
+ Returns:
+ 2-tuple:
+
+ - (None or np.array): Action interpreted from @device including any gripper action(s). None if we get a
+ reset signal from the device
+ - (None or int): 1 if desired close, -1 if desired open gripper state. None if get a reset signal from the
+ device
+
+ """
+ state = device.get_controller_state()
+ # Note: Devices output rotation with x and z flipped to account for robots starting with gripper facing down
+ # Also note that the outputted rotation is an absolute rotation, while outputted dpos is delta pos
+ # Raw delta rotations from neutral user input is captured in raw_drotation (roll, pitch, yaw)
+ dpos, rotation, raw_drotation, grasp, reset = (
+ state["dpos"],
+ state["rotation"],
+ state["raw_drotation"],
+ state["grasp"],
+ state["reset"],
+ )
+
+ # If we're resetting, immediately return None
+ if reset:
+ return None, None
+
+ # Get controller reference
+ controller = robot.controller if not isinstance(robot, Bimanual) else robot.controller[active_arm]
+ gripper_dof = robot.gripper.dof if not isinstance(robot, Bimanual) else robot.gripper[active_arm].dof
+
+ # First process the raw drotation
+ drotation = raw_drotation[[1, 0, 2]]
+ if controller.name == "IK_POSE":
+ # If this is panda, want to swap x and y axis
+ if isinstance(robot.robot_model, Panda):
+ drotation = drotation[[1, 0, 2]]
+ else:
+ # Flip x
+ drotation[0] = -drotation[0]
+ # Scale rotation for teleoperation (tuned for IK)
+ drotation *= 10
+ dpos *= 5
+ # relative rotation of desired from current eef orientation
+ # map to quat
+ drotation = T.mat2quat(T.euler2mat(drotation))
+
+ # If we're using a non-forward facing configuration, need to adjust relative position / orientation
+ if env_configuration == "single-arm-opposed":
+ # Swap x and y for pos and flip x,y signs for ori
+ dpos = dpos[[1, 0, 2]]
+ drotation[0] = -drotation[0]
+ drotation[1] = -drotation[1]
+ if active_arm == "left":
+ # x pos needs to be flipped
+ dpos[0] = -dpos[0]
+ else:
+ # y pos needs to be flipped
+ dpos[1] = -dpos[1]
+
+ # Lastly, map to axis angle form
+ drotation = T.quat2axisangle(drotation)
+
+ elif controller.name == "OSC_POSE":
+ # Flip z
+ drotation[2] = -drotation[2]
+ # Scale rotation for teleoperation (tuned for OSC) -- gains tuned for each device
+ drotation = drotation * 1.5 if isinstance(device, Keyboard) else drotation * 50
+ dpos = dpos * 75 if isinstance(device, Keyboard) else dpos * 125
+ elif controller.name == "OSC_POSITION":
+ dpos = dpos * 75 if isinstance(device, Keyboard) else dpos * 125
+ else:
+ # No other controllers currently supported
+ print("Error: Unsupported controller specified -- Robot must have either an IK or OSC-based controller!")
+
+ # map 0 to -1 (open) and map 1 to 1 (closed)
+ grasp = 1 if grasp else -1
+
+ # Create action based on action space of individual robot
+ if controller.name == "OSC_POSITION":
+ action = np.concatenate([dpos, [grasp] * gripper_dof])
+ else:
+ action = np.concatenate([dpos, drotation, [grasp] * gripper_dof])
+
+ # Return the action and grasp
+ return action, grasp
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/log_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/log_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e902fc3788c8a34c60a728d8cf889c66f4dbb370
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/log_utils.py
@@ -0,0 +1,102 @@
+"""
+This file contains utility classes and functions for logging to stdout and stderr
+Adapted from robomimic: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/log_utils.py
+"""
+import logging
+import os
+import time
+
+from termcolor import colored
+
+import robosuite.macros as macros
+
+LEVEL_COLORS = {
+ logging.DEBUG: "green",
+ logging.INFO: "green",
+ logging.WARNING: "yellow",
+ logging.ERROR: "red",
+ logging.CRITICAL: "red",
+}
+
+FORMAT_STR = {"file": "[robosuite %(levelname)s - %(asctime)s] ", "console": "[robosuite %(levelname)s] "}
+
+MESSAGE_STR = "%(message)s (%(filename)s:%(lineno)d)"
+
+
+class FileFormatter(logging.Formatter):
+ """Formatter class of logging for file logging."""
+
+ FORMATS = {
+ levelno: colored(FORMAT_STR["file"], color, attrs=["bold"]) + MESSAGE_STR
+ for (levelno, color) in LEVEL_COLORS.items()
+ }
+
+ def format(self, record):
+ """Apply custom fomatting on LogRecord object record."""
+ log_fmt = self.FORMATS.get(record.levelno)
+ formatter = logging.Formatter(log_fmt, "%Y-%m-%d %H:%M:%S")
+ return formatter.format(record)
+
+
+class ConsoleFormatter(logging.Formatter):
+ """Formatter class of logging for console logging."""
+
+ FORMATS = {
+ logging.DEBUG: FORMAT_STR["console"] + MESSAGE_STR,
+ logging.INFO: "%(message)s",
+ logging.WARNING: colored(FORMAT_STR["console"], "yellow", attrs=["bold"]) + MESSAGE_STR,
+ logging.ERROR: colored(FORMAT_STR["console"], "red", attrs=["bold"]) + MESSAGE_STR,
+ logging.CRITICAL: colored(FORMAT_STR["console"], "red", attrs=["bold", "reverse"]) + MESSAGE_STR,
+ }
+
+ def format(self, record):
+ """Apply custom fomatting on LogRecord object record."""
+ log_fmt = self.FORMATS.get(record.levelno)
+ formatter = logging.Formatter(log_fmt)
+ return formatter.format(record)
+
+
+class DefaultLogger:
+ """Default logger class in robosuite codebase."""
+
+ def __init__(self, logger_name="robosuite_logs", console_logging_level="INFO", file_logging_level=None):
+ """
+ Args:
+ logger_name (str, optional): logger name. Defaults to "robosuite_logs".
+ console_logging_level (str, optional): logging level for console logging. Defaults to "INFO".
+ file_logging_level (_type_, optional): logging level for file logging. Defaults to None.
+ """
+ self.logger_name = logger_name
+ logger = logging.getLogger(self.logger_name)
+
+ if file_logging_level is not None:
+ time_str = str(time.time()).replace(".", "_")
+ log_file_path = "/tmp/robosuite_{}_{}.log".format(time_str, os.getpid())
+ fh = logging.FileHandler(log_file_path)
+ print(colored("[robosuite]: Saving logs to {}".format(log_file_path), "yellow"))
+ fh.setLevel(logging.getLevelName(file_logging_level))
+ file_formatter = FileFormatter()
+ fh.setFormatter(file_formatter)
+ logger.addHandler(fh)
+
+ if console_logging_level is not None:
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.getLevelName(console_logging_level))
+ console_formatter = ConsoleFormatter()
+ ch.setFormatter(console_formatter)
+ logger.addHandler(ch)
+
+ def get_logger(self):
+ """_summary_
+
+ Returns:
+ DefaultLogger: The retrieved logger whose name equals self.logger_name
+ """
+ logger = logging.getLogger(self.logger_name)
+ return logger
+
+
+ROBOSUITE_DEFAULT_LOGGER = DefaultLogger(
+ console_logging_level=macros.CONSOLE_LOGGING_LEVEL,
+ file_logging_level=macros.FILE_LOGGING_LEVEL,
+).get_logger()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/mjcf_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/mjcf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af7f647763622f661ed518d16c37add5fa08541
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/mjcf_utils.py
@@ -0,0 +1,855 @@
+# utility functions for manipulating MJCF XML models
+
+import os
+import xml.etree.ElementTree as ET
+from collections.abc import Iterable
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+from PIL import Image
+
+import robosuite
+
+RED = [1, 0, 0, 1]
+GREEN = [0, 1, 0, 1]
+BLUE = [0, 0, 1, 1]
+CYAN = [0, 1, 1, 1]
+ROBOT_COLLISION_COLOR = [0, 0.5, 0, 1]
+MOUNT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
+GRIPPER_COLLISION_COLOR = [0, 0, 0.5, 1]
+OBJECT_COLLISION_COLOR = [0.5, 0, 0, 1]
+ENVIRONMENT_COLLISION_COLOR = [0.5, 0.5, 0, 1]
+SENSOR_TYPES = {
+ "touch",
+ "accelerometer",
+ "velocimeter",
+ "gyro",
+ "force",
+ "torque",
+ "magnetometer",
+ "rangefinder",
+ "jointpos",
+ "jointvel",
+ "tendonpos",
+ "tendonvel",
+ "actuatorpos",
+ "actuatorvel",
+ "actuatorfrc",
+ "ballangvel",
+ "jointlimitpos",
+ "jointlimitvel",
+ "jointlimitfrc",
+ "tendonlimitpos",
+ "tendonlimitvel",
+ "tendonlimitfrc",
+ "framepos",
+ "framequat",
+ "framexaxis",
+ "frameyaxis",
+ "framezaxis",
+ "framelinvel",
+ "frameangvel",
+ "framelinacc",
+ "frameangacc",
+ "subtreecom",
+ "subtreelinvel",
+ "subtreeangmom",
+ "user",
+}
+
+MUJOCO_NAMED_ATTRIBUTES = {
+ "class",
+ "childclass",
+ "name",
+ "objname",
+ "material",
+ "texture",
+ "joint",
+ "joint1",
+ "joint2",
+ "jointinparent",
+ "geom",
+ "geom1",
+ "geom2",
+ "mesh",
+ "fixed",
+ "actuator",
+ "objname",
+ "tendon",
+ "tendon1",
+ "tendon2",
+ "slidesite",
+ "cranksite",
+ "body",
+ "body1",
+ "body2",
+ "hfield",
+ "target",
+ "prefix",
+ "site",
+}
+
+IMAGE_CONVENTION_MAPPING = {
+ "opengl": 1,
+ "opencv": -1,
+}
+
+TEXTURE_FILES = {
+ "WoodRed": "red-wood.png",
+ "WoodGreen": "green-wood.png",
+ "WoodBlue": "blue-wood.png",
+ "WoodLight": "light-wood.png",
+ "WoodDark": "dark-wood.png",
+ "WoodTiles": "wood-tiles.png",
+ "WoodPanels": "wood-varnished-panels.png",
+ "WoodgrainGray": "gray-woodgrain.png",
+ "PlasterCream": "cream-plaster.png",
+ "PlasterPink": "pink-plaster.png",
+ "PlasterYellow": "yellow-plaster.png",
+ "PlasterGray": "gray-plaster.png",
+ "PlasterWhite": "white-plaster.png",
+ "BricksWhite": "white-bricks.png",
+ "Metal": "metal.png",
+ "SteelBrushed": "steel-brushed.png",
+ "SteelScratched": "steel-scratched.png",
+ "Brass": "brass-ambra.png",
+ "Bread": "bread.png",
+ "Can": "can.png",
+ "Ceramic": "ceramic.png",
+ "Cereal": "cereal.png",
+ "Clay": "clay.png",
+ "Dirt": "dirt.png",
+ "Glass": "glass.png",
+ "FeltGray": "gray-felt.png",
+ "Lemon": "lemon.png",
+}
+
+TEXTURES = {
+ texture_name: os.path.join("textures", texture_file) for (texture_name, texture_file) in TEXTURE_FILES.items()
+}
+
+ALL_TEXTURES = TEXTURES.keys()
+
+
+class CustomMaterial(object):
+ """
+ Simple class to instantiate the necessary parameters to define an appropriate texture / material combo
+
+ Instantiates a nested dict holding necessary components for procedurally generating a texture / material combo
+
+ Please see http://www.mujoco.org/book/XMLreference.html#asset for specific details on
+ attributes expected for Mujoco texture / material tags, respectively
+
+ Note that the values in @tex_attrib and @mat_attrib can be in string or array / numerical form.
+
+ Args:
+ texture (None or str or 4-array): Name of texture file to be imported. If a string, should be part of
+ ALL_TEXTURES. If texture is a 4-array, then this argument will be interpreted as an rgba tuple value and
+ a template png will be procedurally generated during object instantiation, with any additional
+ texture / material attributes specified. If None, no file will be linked and no rgba value will be set
+ Note, if specified, the RGBA values are expected to be floats between 0 and 1
+
+ tex_name (str): Name to reference the imported texture
+
+ mat_name (str): Name to reference the imported material
+
+ tex_attrib (dict): Any other optional mujoco texture specifications.
+
+ mat_attrib (dict): Any other optional mujoco material specifications.
+
+ shared (bool): If True, this material should not have any naming prefixes added to all names
+
+ Raises:
+ AssertionError: [Invalid texture]
+ """
+
+ def __init__(
+ self,
+ texture,
+ tex_name,
+ mat_name,
+ tex_attrib=None,
+ mat_attrib=None,
+ shared=False,
+ ):
+ # Check if the desired texture is an rgba value
+ if type(texture) is str:
+ default = False
+ # Verify that requested texture is valid
+ assert texture in ALL_TEXTURES, "Error: Requested invalid texture. Got {}. Valid options are:\n{}".format(
+ texture, ALL_TEXTURES
+ )
+ else:
+ default = True
+ # If specified, this is an rgba value and a default texture is desired; make sure length of rgba array is 4
+ if texture is not None:
+ assert len(texture) == 4, (
+ "Error: Requested default texture. Got array of length {}."
+ "Expected rgba array of length 4.".format(len(texture))
+ )
+
+ # Setup the texture and material attributes
+ self.tex_attrib = {} if tex_attrib is None else tex_attrib.copy()
+ self.mat_attrib = {} if mat_attrib is None else mat_attrib.copy()
+
+ # Add in name values
+ self.name = mat_name
+ self.shared = shared
+ self.tex_attrib["name"] = tex_name
+ self.mat_attrib["name"] = mat_name
+ self.mat_attrib["texture"] = tex_name
+
+ # Loop through all attributes and convert all non-string values into strings
+ for attrib in (self.tex_attrib, self.mat_attrib):
+ for k, v in attrib.items():
+ if type(v) is not str:
+ if isinstance(v, Iterable):
+ attrib[k] = array_to_string(v)
+ else:
+ attrib[k] = str(v)
+
+ # Handle default and non-default cases separately for linking texture patch file locations
+ if not default:
+ # Add in the filepath to texture patch
+ self.tex_attrib["file"] = xml_path_completion(TEXTURES[texture])
+ else:
+ if texture is not None:
+ # Create a texture patch
+ tex = Image.new("RGBA", (100, 100), tuple((np.array(texture) * 255).astype("int")))
+ # Create temp directory if it does not exist
+ save_dir = "/tmp/robosuite_temp_tex"
+ Path(save_dir).mkdir(parents=True, exist_ok=True)
+ # Save this texture patch to the temp directory on disk (MacOS / Linux)
+ fpath = save_dir + "/{}.png".format(tex_name)
+ tex.save(fpath, "PNG")
+ # Link this texture file to the default texture dict
+ self.tex_attrib["file"] = fpath
+
+
+def xml_path_completion(xml_path):
+ """
+ Takes in a local xml path and returns a full path.
+ if @xml_path is absolute, do nothing
+ if @xml_path is not absolute, load xml that is shipped by the package
+
+ Args:
+ xml_path (str): local xml path
+
+ Returns:
+ str: Full (absolute) xml path
+ """
+ if xml_path.startswith("/"):
+ full_path = xml_path
+ else:
+ full_path = os.path.join(robosuite.models.assets_root, xml_path)
+ return full_path
+
+
+def array_to_string(array):
+ """
+ Converts a numeric array into the string format in mujoco.
+
+ Examples:
+ [0, 1, 2] => "0 1 2"
+
+ Args:
+ array (n-array): Array to convert to a string
+
+ Returns:
+ str: String equivalent of @array
+ """
+ return " ".join(["{}".format(x) for x in array])
+
+
+def string_to_array(string):
+ """
+ Converts a array string in mujoco xml to np.array.
+
+ Examples:
+ "0 1 2" => [0, 1, 2]
+
+ Args:
+ string (str): String to convert to an array
+
+ Returns:
+ np.array: Numerical array equivalent of @string
+ """
+ return np.array([float(x) for x in string.strip().split(" ")])
+
+
+def convert_to_string(inp):
+ """
+ Converts any type of {bool, int, float, list, tuple, array, string, np.str_} into an mujoco-xml compatible string.
+ Note that an input string / np.str_ results in a no-op action.
+
+ Args:
+ inp: Input to convert to string
+
+ Returns:
+ str: String equivalent of @inp
+ """
+ if type(inp) in {list, tuple, np.ndarray}:
+ return array_to_string(inp)
+ elif type(inp) in {int, float, bool}:
+ return str(inp).lower()
+ elif type(inp) in {str, np.str_}:
+ return inp
+ else:
+ raise ValueError("Unsupported type received: got {}".format(type(inp)))
+
+
+def set_alpha(node, alpha=0.1):
+ """
+ Sets all a(lpha) field of the rgba attribute to be @alpha
+ for @node and all subnodes
+ used for managing display
+
+ Args:
+ node (ET.Element): Specific node element within XML tree
+ alpha (float): Value to set alpha value of rgba tuple
+ """
+ for child_node in node.findall(".//*[@rgba]"):
+ rgba_orig = string_to_array(child_node.get("rgba"))
+ child_node.set("rgba", array_to_string(list(rgba_orig[0:3]) + [alpha]))
+
+
+def new_element(tag, name, **kwargs):
+ """
+ Creates a new @tag element with attributes specified by @**kwargs.
+
+ Args:
+ tag (str): Type of element to create
+ name (None or str): Name for this element. Should only be None for elements that do not have an explicit
+ name attribute (e.g.: inertial elements)
+ **kwargs: Specified attributes for the new joint
+
+ Returns:
+ ET.Element: new specified xml element
+ """
+ # Name will be set if it's not None
+ if name is not None:
+ kwargs["name"] = name
+ # Loop through all attributes and pop any that are None, otherwise convert them to strings
+ for k, v in kwargs.copy().items():
+ if v is None:
+ kwargs.pop(k)
+ else:
+ kwargs[k] = convert_to_string(v)
+ element = ET.Element(tag, attrib=kwargs)
+ return element
+
+
+def new_joint(name, **kwargs):
+ """
+ Creates a joint tag with attributes specified by @**kwargs.
+
+ Args:
+ name (str): Name for this joint
+ **kwargs: Specified attributes for the new joint
+
+ Returns:
+ ET.Element: new joint xml element
+ """
+ return new_element(tag="joint", name=name, **kwargs)
+
+
+def new_actuator(name, joint, act_type="actuator", **kwargs):
+ """
+ Creates an actuator tag with attributes specified by @**kwargs.
+
+ Args:
+ name (str): Name for this actuator
+ joint (str): type of actuator transmission.
+ see all types here: http://mujoco.org/book/modeling.html#actuator
+ act_type (str): actuator type. Defaults to "actuator"
+ **kwargs: Any additional specified attributes for the new joint
+
+ Returns:
+ ET.Element: new actuator xml element
+ """
+ element = new_element(tag=act_type, name=name, **kwargs)
+ element.set("joint", joint)
+ return element
+
+
+def new_site(name, rgba=RED, pos=(0, 0, 0), size=(0.005,), **kwargs):
+ """
+ Creates a site element with attributes specified by @**kwargs.
+
+ NOTE: With the exception of @name, @pos, and @size, if any arg is set to
+ None, the value will automatically be popped before passing the values
+ to create the appropriate XML
+
+ Args:
+ name (str): Name for this site
+ rgba (4-array): (r,g,b,a) color and transparency. Defaults to solid red.
+ pos (3-array): (x,y,z) 3d position of the site.
+ size (n-array of float): site size (sites are spherical by default).
+ **kwargs: Any additional specified attributes for the new site
+
+ Returns:
+ ET.Element: new site xml element
+ """
+ kwargs["pos"] = pos
+ kwargs["size"] = size
+ kwargs["rgba"] = rgba if rgba is not None else None
+ return new_element(tag="site", name=name, **kwargs)
+
+
+def new_geom(name, type, size, pos=(0, 0, 0), group=0, **kwargs):
+ """
+ Creates a geom element with attributes specified by @**kwargs.
+
+ NOTE: With the exception of @geom_type, @size, and @pos, if any arg is set to
+ None, the value will automatically be popped before passing the values
+ to create the appropriate XML
+
+ Args:
+ name (str): Name for this geom
+ type (str): type of the geom.
+ see all types here: http://mujoco.org/book/modeling.html#geom
+ size (n-array of float): geom size parameters.
+ pos (3-array): (x,y,z) 3d position of the site.
+ group (int): the integrer group that the geom belongs to. useful for
+ separating visual and physical elements.
+ **kwargs: Any additional specified attributes for the new geom
+
+ Returns:
+ ET.Element: new geom xml element
+ """
+ kwargs["type"] = type
+ kwargs["size"] = size
+ kwargs["pos"] = pos
+ kwargs["group"] = group if group is not None else None
+ return new_element(tag="geom", name=name, **kwargs)
+
+
+def new_body(name, pos=(0, 0, 0), **kwargs):
+ """
+ Creates a body element with attributes specified by @**kwargs.
+
+ Args:
+ name (str): Name for this body
+ pos (3-array): (x,y,z) 3d position of the body frame.
+ **kwargs: Any additional specified attributes for the new body
+
+ Returns:
+ ET.Element: new body xml element
+ """
+ kwargs["pos"] = pos
+ return new_element(tag="body", name=name, **kwargs)
+
+
+def new_inertial(pos=(0, 0, 0), mass=None, **kwargs):
+ """
+ Creates a inertial element with attributes specified by @**kwargs.
+
+ Args:
+ pos (3-array): (x,y,z) 3d position of the inertial frame.
+ mass (float): The mass of inertial
+ **kwargs: Any additional specified attributes for the new inertial element
+
+ Returns:
+ ET.Element: new inertial xml element
+ """
+ kwargs["mass"] = mass if mass is not None else None
+ kwargs["pos"] = pos
+ return new_element(tag="inertial", name=None, **kwargs)
+
+
+def get_size(size, size_max, size_min, default_max, default_min):
+ """
+ Helper method for providing a size, or a range to randomize from
+
+ Args:
+ size (n-array): Array of numbers that explicitly define the size
+ size_max (n-array): Array of numbers that define the custom max size from which to randomly sample
+ size_min (n-array): Array of numbers that define the custom min size from which to randomly sample
+ default_max (n-array): Array of numbers that define the default max size from which to randomly sample
+ default_min (n-array): Array of numbers that define the default min size from which to randomly sample
+
+ Returns:
+ np.array: size generated
+
+ Raises:
+ ValueError: [Inconsistent array sizes]
+ """
+ if len(default_max) != len(default_min):
+ raise ValueError(
+ "default_max = {} and default_min = {}".format(str(default_max), str(default_min))
+ + " have different lengths"
+ )
+ if size is not None:
+ if (size_max is not None) or (size_min is not None):
+ raise ValueError("size = {} overrides size_max = {}, size_min = {}".format(size, size_max, size_min))
+ else:
+ if size_max is None:
+ size_max = default_max
+ if size_min is None:
+ size_min = default_min
+ size = np.array([np.random.uniform(size_min[i], size_max[i]) for i in range(len(default_max))])
+ return np.array(size)
+
+
+def add_to_dict(dic, fill_in_defaults=True, default_value=None, **kwargs):
+ """
+ Helper function to add key-values to dictionary @dic where each entry is its own array (list).
+ Args:
+ dic (dict): Dictionary to which new key / value pairs will be added. If the key already exists,
+ will append the value to that key entry
+ fill_in_defaults (bool): If True, will automatically add @default_value to all dictionary entries that are
+ not explicitly specified in @kwargs
+ default_value (any): Default value to fill (None by default)
+
+ Returns:
+ dict: Modified dictionary
+ """
+ # Get keys and length of array for a given entry in dic
+ keys = set(dic.keys())
+ n = len(list(keys)[0]) if keys else 0
+ for k, v in kwargs.items():
+ if k in dic:
+ dic[k].append(v)
+ keys.remove(k)
+ else:
+ dic[k] = [default_value] * n + [v] if fill_in_defaults else [v]
+ # If filling in defaults, fill in remaining default values
+ if fill_in_defaults:
+ for k in keys:
+ dic[k].append(default_value)
+ return dic
+
+
+def add_prefix(
+ root,
+ prefix,
+ tags="default",
+ attribs="default",
+ exclude=None,
+):
+ """
+ Find all element(s) matching the requested @tag, and appends @prefix to all @attributes if they exist.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through.
+ prefix (str): Prefix to add to all specified attributes
+ tags (str or list of str or set): Tag(s) to search for in this ElementTree. "Default" corresponds to all tags
+ attribs (str or list of str or set): Element attribute(s) to append prefix to. "Default" corresponds
+ to all attributes that reference names
+ exclude (None or function): Filtering function that should take in an ET.Element or a string (attribute) and
+ return True if we should exclude the given element / attribute from having any prefixes added
+ """
+ # Standardize tags and attributes to be a set
+ if tags != "default":
+ tags = {tags} if type(tags) is str else set(tags)
+ if attribs == "default":
+ attribs = MUJOCO_NAMED_ATTRIBUTES
+ attribs = {attribs} if type(attribs) is str else set(attribs)
+
+ # Check the current element for matching conditions
+ if (tags == "default" or root.tag in tags) and (exclude is None or not exclude(root)):
+ for attrib in attribs:
+ v = root.get(attrib, None)
+ # Only add prefix if the attribute exist, the current attribute doesn't already begin with prefix,
+ # and the @exclude filter is either None or returns False
+ if v is not None and not v.startswith(prefix) and (exclude is None or not exclude(v)):
+ root.set(attrib, prefix + v)
+ # Continue recursively searching through the element tree
+ for r in root:
+ add_prefix(root=r, prefix=prefix, tags=tags, attribs=attribs, exclude=exclude)
+
+
+def add_material(root, naming_prefix="", custom_material=None):
+ """
+ Iterates through all element(s) in @root recursively and adds a material / texture to all visual geoms that don't
+ already have a material specified.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through.
+ naming_prefix (str): Adds this prefix to all material and texture names
+ custom_material (None or CustomMaterial): If specified, will add this material to all visual geoms.
+ Else, will add a default "no-change" material.
+
+ Returns:
+ 4-tuple: (ET.Element, ET.Element, CustomMaterial, bool) (tex_element, mat_element, material, used)
+ corresponding to the added material and whether the material was actually used or not.
+ """
+ # Initialize used as False
+ used = False
+ # First, make sure material is specified
+ if custom_material is None:
+ custom_material = CustomMaterial(
+ texture=None,
+ tex_name="default_tex",
+ mat_name="default_mat",
+ tex_attrib={
+ "type": "cube",
+ "builtin": "flat",
+ "width": 100,
+ "height": 100,
+ "rgb1": np.ones(3),
+ "rgb2": np.ones(3),
+ },
+ )
+ # Else, check to make sure the custom material begins with the specified prefix and that it's unique
+ if not custom_material.name.startswith(naming_prefix) and not custom_material.shared:
+ custom_material.name = naming_prefix + custom_material.name
+ custom_material.tex_attrib["name"] = naming_prefix + custom_material.tex_attrib["name"]
+ custom_material.mat_attrib["name"] = naming_prefix + custom_material.mat_attrib["name"]
+ custom_material.mat_attrib["texture"] = naming_prefix + custom_material.mat_attrib["texture"]
+
+ # Check the current element for matching conditions
+ if root.tag == "geom" and root.get("group", None) == "1" and root.get("material", None) is None:
+ # Add a new material attribute to this geom
+ root.set("material", custom_material.name)
+ # Set used to True
+ used = True
+ # Continue recursively searching through the element tree
+ for r in root:
+ _, _, _, _used = add_material(root=r, naming_prefix=naming_prefix, custom_material=custom_material)
+ # Update used
+ used = used or _used
+ # Lastly, return the new texture and material elements
+ tex_element = new_element(tag="texture", **custom_material.tex_attrib)
+ mat_element = new_element(tag="material", **custom_material.mat_attrib)
+ return tex_element, mat_element, custom_material, used
+
+
+def recolor_collision_geoms(root, rgba, exclude=None):
+ """
+ Iteratively searches through all elements starting with @root to find all geoms belonging to group 0 and set
+ the corresponding rgba value to the specified @rgba argument. Note: also removes any material values for these
+ elements.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through
+ rgba (4-array): (R, G, B, A) values to assign to all geoms with this group.
+ exclude (None or function): Filtering function that should take in an ET.Element and
+ return True if we should exclude the given element / attribute from having its collision geom impacted.
+ """
+ # Check this body
+ if root.tag == "geom" and root.get("group") in {None, "0"} and (exclude is None or not exclude(root)):
+ root.set("rgba", array_to_string(rgba))
+ root.attrib.pop("material", None)
+
+ # Iterate through all children elements
+ for r in root:
+ recolor_collision_geoms(root=r, rgba=rgba, exclude=exclude)
+
+
+def _element_filter(element, parent):
+ """
+ Default element filter to be used in sort_elements. This will filter for the following groups:
+
+ :`'root_body'`: Top-level body element
+ :`'bodies'`: Any body elements
+ :`'joints'`: Any joint elements
+ :`'actuators'`: Any actuator elements
+ :`'sites'`: Any site elements
+ :`'sensors'`: Any sensor elements
+ :`'contact_geoms'`: Any geoms used for collision (as specified by group 0 (default group) geoms)
+ :`'visual_geoms'`: Any geoms used for visual rendering (as specified by group 1 geoms)
+
+ Args:
+ element (ET.Element): Current XML element that we are filtering
+ parent (ET.Element): Parent XML element for the current element
+
+ Returns:
+ str or None: Assigned filter key for this element. None if no matching filter is found.
+ """
+ # Check for actuator first since this is dependent on the parent element
+ if parent is not None and parent.tag == "actuator":
+ return "actuators"
+ elif element.tag == "joint":
+ # Make sure this is not a tendon (this should not have a "joint", "joint1", or "joint2" attribute specified)
+ if element.get("joint") is None and element.get("joint1") is None:
+ return "joints"
+ elif element.tag == "body":
+ # If the parent of this does not have a tag "body", then this is the top-level body element
+ if parent is None or parent.tag != "body":
+ return "root_body"
+ return "bodies"
+ elif element.tag == "site":
+ return "sites"
+ elif element.tag in SENSOR_TYPES:
+ return "sensors"
+ elif element.tag == "geom":
+ # Only get collision and visual geoms (group 0 / None, or 1, respectively)
+ group = element.get("group")
+ if group in {None, "0", "1"}:
+ return "visual_geoms" if group == "1" else "contact_geoms"
+ else:
+ # If no condition met, return None
+ return None
+
+
+def sort_elements(root, parent=None, element_filter=None, _elements_dict=None):
+ """
+ Utility method to iteratively sort all elements based on @tags. This XML ElementTree will be parsed such that
+ all elements with the same key as returned by @element_filter will be grouped as a list entry in the returned
+ dictionary.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through
+ parent (ET.Element): Parent of the root node. Default is None (no parent node initially)
+ element_filter (None or function): Function used to filter the incoming elements. Should take in two
+ ET.Elements (current_element, parent_element) and return a string filter_key if the element
+ should be added to the list of values sorted by filter_key, and return None if no value should be added.
+ If no element_filter is specified, defaults to self._element_filter.
+ _elements_dict (dict): Dictionary that gets passed to recursive calls. Should not be modified externally by
+ top-level call.
+
+ Returns:
+ dict: Filtered key-specific lists of the corresponding elements
+ """
+ # Initialize dictionary and element filter if None is set
+ if _elements_dict is None:
+ _elements_dict = {}
+ if element_filter is None:
+ element_filter = _element_filter
+
+ # Parse this element
+ key = element_filter(root, parent)
+ if key is not None:
+ # Initialize new entry in the dict if this is the first time encountering this value, otherwise append
+ if key not in _elements_dict:
+ _elements_dict[key] = [root]
+ else:
+ _elements_dict[key].append(root)
+
+ # Loop through all possible subtrees for this XML recurisvely
+ for r in root:
+ _elements_dict = sort_elements(
+ root=r, parent=root, element_filter=element_filter, _elements_dict=_elements_dict
+ )
+
+ return _elements_dict
+
+
+def find_parent(root, child):
+ """
+ Find the parent element of the specified @child node, recurisvely searching through @root.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through.
+ child (ET.Element): Child element whose parent is to be found
+
+ Returns:
+ None or ET.Element: Matching parent if found, else None
+ """
+ # Iterate through children (DFS), if the correct child element is found, then return the current root as the parent
+ for r in root:
+ if r == child:
+ return root
+ parent = find_parent(root=r, child=child)
+ if parent is not None:
+ return parent
+ # If we get here, we didn't find anything ):
+ return None
+
+
+def find_elements(root, tags, attribs=None, return_first=True):
+ """
+ Find all element(s) matching the requested @tag and @attributes. If @return_first is True, then will return the
+ first element found matching the criteria specified. Otherwise, will return a list of elements that match the
+ criteria.
+
+ Args:
+ root (ET.Element): Root of the xml element tree to start recursively searching through.
+ tags (str or list of str or set): Tag(s) to search for in this ElementTree.
+ attribs (None or dict of str): Element attribute(s) to check against for a filtered element. A match is
+ considered found only if all attributes match. Each attribute key should have a corresponding value with
+ which to compare against.
+ return_first (bool): Whether to immediately return once the first matching element is found.
+
+ Returns:
+ None or ET.Element or list of ET.Element: Matching element(s) found. Returns None if there was no match.
+ """
+ # Initialize return value
+ elements = None if return_first else []
+
+ # Make sure tags is list
+ tags = [tags] if type(tags) is str else tags
+
+ # Check the current element for matching conditions
+ if root.tag in tags:
+ matching = True
+ if attribs is not None:
+ for k, v in attribs.items():
+ if root.get(k) != v:
+ matching = False
+ break
+ # If all criteria were matched, add this to the solution (or return immediately if specified)
+ if matching:
+ if return_first:
+ return root
+ else:
+ elements.append(root)
+ # Continue recursively searching through the element tree
+ for r in root:
+ if return_first:
+ elements = find_elements(tags=tags, attribs=attribs, root=r, return_first=return_first)
+ if elements is not None:
+ return elements
+ else:
+ found_elements = find_elements(tags=tags, attribs=attribs, root=r, return_first=return_first)
+ pre_elements = deepcopy(elements)
+ if found_elements:
+ elements += found_elements if type(found_elements) is list else [found_elements]
+
+ return elements if elements else None
+
+
+def save_sim_model(sim, fname):
+ """
+ Saves the current model xml from @sim at file location @fname.
+
+ Args:
+ sim (MjSim): XML file to save, in string form
+ fname (str): Absolute filepath to the location to save the file
+ """
+ with open(fname, "w") as f:
+ sim.save(file=f, format="xml")
+
+
+def get_ids(sim, elements, element_type="geom", inplace=False):
+ """
+ Grabs the mujoco IDs for each element in @elements, corresponding to the specified @element_type.
+
+ Args:
+ sim (MjSim): Active mujoco simulation object
+ elements (str or list or dict): Element(s) to convert into IDs. Note that the return type corresponds to
+ @elements type, where each element name is replaced with the ID
+ element_type (str): The type of element to grab ID for. Options are {geom, body, site}
+ inplace (bool): If False, will create a copy of @elements to prevent overwriting the original data structure
+
+ Returns:
+ str or list or dict: IDs corresponding to @elements.
+ """
+ if not inplace:
+ # Copy elements first so we don't write to the underlying object
+ elements = deepcopy(elements)
+ # Choose what to do based on elements type
+ if isinstance(elements, str):
+ # We simply return the value of this single element
+ assert element_type in {
+ "geom",
+ "body",
+ "site",
+ }, f"element_type must be either geom, body, or site. Got: {element_type}"
+ if element_type == "geom":
+ elements = sim.model.geom_name2id(elements)
+ elif element_type == "body":
+ elements = sim.model.body_name2id(elements)
+ else: # site
+ elements = sim.model.site_name2id(elements)
+ elif isinstance(elements, dict):
+ # Iterate over each element in dict and recursively repeat
+ for name, ele in elements:
+ elements[name] = get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True)
+ else: # We assume this is an iterable array
+ assert isinstance(elements, Iterable), "Elements must be iterable for get_id!"
+ elements = [get_ids(sim=sim, elements=ele, element_type=element_type, inplace=True) for ele in elements]
+
+ return elements
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/mjmod.py b/phantom/submodules/phantom-robosuite/robosuite/utils/mjmod.py
new file mode 100644
index 0000000000000000000000000000000000000000..3712e619906214e4221b4937b655b7092aba7611
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/mjmod.py
@@ -0,0 +1,1997 @@
+"""
+Modder classes used for domain randomization. Largely based off of the mujoco-py
+implementation below.
+
+https://github.com/openai/mujoco-py/blob/1fe312b09ae7365f0dd9d4d0e453f8da59fae0bf/mujoco_py/modder.py
+"""
+
+import copy
+import os
+from collections import defaultdict
+
+import numpy as np
+from PIL import Image
+
+import robosuite
+import robosuite.utils.transform_utils as trans
+from robosuite.utils.binding_utils import MjRenderContextOffscreen
+
+
+class BaseModder:
+ """
+ Base class meant to modify simulation attributes mid-sim.
+
+ Using @random_state ensures that sampling here won't be affected
+ by sampling that happens outside of the modders.
+
+ Args:
+ sim (MjSim): simulation object
+
+ random_state (RandomState): instance of np.random.RandomState, specific
+ seed used to randomize these modifications without impacting other
+ numpy seeds / randomizations
+ """
+
+ def __init__(self, sim, random_state=None):
+ self.sim = sim
+ if random_state is None:
+ # default to global RandomState instance
+ self.random_state = np.random.mtrand._rand
+ else:
+ self.random_state = random_state
+
+ def update_sim(self, sim):
+ """
+ Setter function to update internal sim variable
+
+ Args:
+ sim (MjSim): MjSim object
+ """
+ self.sim = sim
+
+ @property
+ def model(self):
+ """
+ Returns:
+ MjModel: Mujoco sim model
+ """
+ # Available for quick convenience access
+ return self.sim.model
+
+
+class LightingModder(BaseModder):
+ """
+ Modder to modify lighting within a Mujoco simulation.
+
+ Args:
+ sim (MjSim): MjSim object
+
+ random_state (RandomState): instance of np.random.RandomState
+
+ light_names (None or list of str): list of lights to use for randomization. If not provided, all
+ lights in the model are randomized.
+
+ randomize_position (bool): If True, randomizes position of lighting
+
+ randomize_direction (bool): If True, randomizes direction of lighting
+
+ randomize_specular (bool): If True, randomizes specular attribute of lighting
+
+ randomize_ambient (bool): If True, randomizes ambient attribute of lighting
+
+ randomize_diffuse (bool): If True, randomizes diffuse attribute of lighting
+
+ randomize_active (bool): If True, randomizes active nature of lighting
+
+ position_perturbation_size (float): Magnitude of position randomization
+
+ direction_perturbation_size (float): Magnitude of direction randomization
+
+ specular_perturbation_size (float): Magnitude of specular attribute randomization
+
+ ambient_perturbation_size (float): Magnitude of ambient attribute randomization
+
+ diffuse_perturbation_size (float): Magnitude of diffuse attribute randomization
+ """
+
+ def __init__(
+ self,
+ sim,
+ random_state=None,
+ light_names=None,
+ randomize_position=True,
+ randomize_direction=True,
+ randomize_specular=True,
+ randomize_ambient=True,
+ randomize_diffuse=True,
+ randomize_active=True,
+ position_perturbation_size=0.1,
+ direction_perturbation_size=0.35, # 20 degrees
+ specular_perturbation_size=0.1,
+ ambient_perturbation_size=0.1,
+ diffuse_perturbation_size=0.1,
+ ):
+ super().__init__(sim, random_state=random_state)
+
+ if light_names is None:
+ light_names = self.sim.model.light_names
+ self.light_names = light_names
+
+ self.randomize_position = randomize_position
+ self.randomize_direction = randomize_direction
+ self.randomize_specular = randomize_specular
+ self.randomize_ambient = randomize_ambient
+ self.randomize_diffuse = randomize_diffuse
+ self.randomize_active = randomize_active
+
+ self.position_perturbation_size = position_perturbation_size
+ self.direction_perturbation_size = direction_perturbation_size
+ self.specular_perturbation_size = specular_perturbation_size
+ self.ambient_perturbation_size = ambient_perturbation_size
+ self.diffuse_perturbation_size = diffuse_perturbation_size
+
+ self.save_defaults()
+
+ def save_defaults(self):
+ """
+ Uses the current MjSim state and model to save default parameter values.
+ """
+ self._defaults = {k: {} for k in self.light_names}
+ for name in self.light_names:
+ self._defaults[name]["pos"] = np.array(self.get_pos(name))
+ self._defaults[name]["dir"] = np.array(self.get_dir(name))
+ self._defaults[name]["specular"] = np.array(self.get_specular(name))
+ self._defaults[name]["ambient"] = np.array(self.get_ambient(name))
+ self._defaults[name]["diffuse"] = np.array(self.get_diffuse(name))
+ self._defaults[name]["active"] = self.get_active(name)
+
+ def restore_defaults(self):
+ """
+ Reloads the saved parameter values.
+ """
+ for name in self.light_names:
+ self.set_pos(name, self._defaults[name]["pos"])
+ self.set_dir(name, self._defaults[name]["dir"])
+ self.set_specular(name, self._defaults[name]["specular"])
+ self.set_ambient(name, self._defaults[name]["ambient"])
+ self.set_diffuse(name, self._defaults[name]["diffuse"])
+ self.set_active(name, self._defaults[name]["active"])
+
+ def randomize(self):
+ """
+ Randomizes all requested lighting values within the sim
+ """
+ for name in self.light_names:
+ if self.randomize_position:
+ self._randomize_position(name)
+
+ if self.randomize_direction:
+ self._randomize_direction(name)
+
+ if self.randomize_specular:
+ self._randomize_specular(name)
+
+ if self.randomize_ambient:
+ self._randomize_ambient(name)
+
+ if self.randomize_diffuse:
+ self._randomize_diffuse(name)
+
+ if self.randomize_active:
+ self._randomize_active(name)
+
+ def _randomize_position(self, name):
+ """
+ Helper function to randomize position of a specific light source
+
+ Args:
+ name (str): Name of the lighting source to randomize for
+ """
+ delta_pos = self.random_state.uniform(
+ low=-self.position_perturbation_size,
+ high=self.position_perturbation_size,
+ size=3,
+ )
+ self.set_pos(
+ name,
+ self._defaults[name]["pos"] + delta_pos,
+ )
+
+ def _randomize_direction(self, name):
+ """
+ Helper function to randomize direction of a specific light source
+
+ Args:
+ name (str): Name of the lighting source to randomize for
+ """
+ # sample a small, random axis-angle delta rotation
+ random_axis, random_angle = trans.random_axis_angle(
+ angle_limit=self.direction_perturbation_size, random_state=self.random_state
+ )
+ random_delta_rot = trans.quat2mat(trans.axisangle2quat(random_axis * random_angle))
+
+ # rotate direction by this delta rotation and set the new direction
+ new_dir = random_delta_rot.dot(self._defaults[name]["dir"])
+ self.set_dir(
+ name,
+ new_dir,
+ )
+
+ def _randomize_specular(self, name):
+ """
+ Helper function to randomize specular attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source to randomize for
+ """
+ delta = self.random_state.uniform(
+ low=-self.specular_perturbation_size,
+ high=self.specular_perturbation_size,
+ size=3,
+ )
+ self.set_specular(
+ name,
+ self._defaults[name]["specular"] + delta,
+ )
+
+ def _randomize_ambient(self, name):
+ """
+ Helper function to randomize ambient attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source to randomize for
+ """
+ delta = self.random_state.uniform(
+ low=-self.ambient_perturbation_size,
+ high=self.ambient_perturbation_size,
+ size=3,
+ )
+ self.set_ambient(
+ name,
+ self._defaults[name]["ambient"] + delta,
+ )
+
+ def _randomize_diffuse(self, name):
+ """
+ Helper function to randomize diffuse attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source to randomize for
+ """
+ delta = self.random_state.uniform(
+ low=-self.diffuse_perturbation_size,
+ high=self.diffuse_perturbation_size,
+ size=3,
+ )
+ self.set_diffuse(
+ name,
+ self._defaults[name]["diffuse"] + delta,
+ )
+
+ def _randomize_active(self, name):
+ """
+ Helper function to randomize active nature of a specific light source
+
+ Args:
+ name (str): Name of the lighting source to randomize for
+ """
+ active = int(self.random_state.uniform() > 0.5)
+ self.set_active(name, active)
+
+ def get_pos(self, name):
+ """
+ Grabs position of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ np.array: (x,y,z) position of lighting source
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ return self.model.light_pos[lightid]
+
+ def set_pos(self, name, value):
+ """
+ Sets position of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+ value (np.array): (x,y,z) position to set lighting source to
+
+ Raises:
+ AssertionError: Invalid light name
+ AssertionError: Invalid @value
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ value = list(value)
+ assert len(value) == 3, "Expected 3-dim value, got %s" % value
+
+ self.model.light_pos[lightid] = value
+
+ def get_dir(self, name):
+ """
+ Grabs direction of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ np.array: (x,y,z) direction of lighting source
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ return self.model.light_dir[lightid]
+
+ def set_dir(self, name, value):
+ """
+ Sets direction of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+ value (np.array): (ax,ay,az) direction to set lighting source to
+
+ Raises:
+ AssertionError: Invalid light name
+ AssertionError: Invalid @value
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ value = list(value)
+ assert len(value) == 3, "Expected 3-dim value, got %s" % value
+
+ self.model.light_dir[lightid] = value
+
+ def get_active(self, name):
+ """
+ Grabs active nature of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ int: Whether light source is active (1) or not (0)
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ return self.model.light_active[lightid]
+
+ def set_active(self, name, value):
+ """
+ Sets active nature of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+ value (int): Whether light source is active (1) or not (0)
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ self.model.light_active[lightid] = value
+
+ def get_specular(self, name):
+ """
+ Grabs specular attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ np.array: (r,g,b) specular color of lighting source
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ return self.model.light_specular[lightid]
+
+ def set_specular(self, name, value):
+ """
+ Sets specular attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+ value (np.array): (r,g,b) specular color to set lighting source to
+
+ Raises:
+ AssertionError: Invalid light name
+ AssertionError: Invalid @value
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ value = list(value)
+ assert len(value) == 3, "Expected 3-dim value, got %s" % value
+
+ self.model.light_specular[lightid] = value
+
+ def get_ambient(self, name):
+ """
+ Grabs ambient attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ np.array: (r,g,b) ambient color of lighting source
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ return self.model.light_ambient[lightid]
+
+ def set_ambient(self, name, value):
+ """
+ Sets ambient attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+ value (np.array): (r,g,b) ambient color to set lighting source to
+
+ Raises:
+ AssertionError: Invalid light name
+ AssertionError: Invalid @value
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ value = list(value)
+ assert len(value) == 3, "Expected 3-dim value, got %s" % value
+
+ self.model.light_ambient[lightid] = value
+
+ def get_diffuse(self, name):
+ """
+ Grabs diffuse attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ np.array: (r,g,b) diffuse color of lighting source
+
+ Raises:
+ AssertionError: Invalid light name
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ return self.model.light_diffuse[lightid]
+
+ def set_diffuse(self, name, value):
+ """
+ Sets diffuse attribute of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+ value (np.array): (r,g,b) diffuse color to set lighting source to
+
+ Raises:
+ AssertionError: Invalid light name
+ AssertionError: Invalid @value
+ """
+ lightid = self.get_lightid(name)
+ assert lightid > -1, "Unkwnown light %s" % name
+
+ value = list(value)
+ assert len(value) == 3, "Expected 3-dim value, got %s" % value
+
+ self.model.light_diffuse[lightid] = value
+
+ def get_lightid(self, name):
+ """
+ Grabs unique id number of a specific light source
+
+ Args:
+ name (str): Name of the lighting source
+
+ Returns:
+ int: id of lighting source. -1 if not found
+ """
+ return self.model.light_name2id(name)
+
+
+class CameraModder(BaseModder):
+ """
+ Modder for modifying camera attributes in mujoco sim
+
+ Args:
+ sim (MjSim): MjSim object
+
+ random_state (None or RandomState): instance of np.random.RandomState
+
+ camera_names (None or list of str): list of camera names to use for randomization. If not provided,
+ all cameras are used for randomization.
+
+ randomize_position (bool): if True, randomize camera position
+
+ randomize_rotation (bool): if True, randomize camera rotation
+
+ randomize_fovy (bool): if True, randomize camera fovy
+
+ position_perturbation_size (float): size of camera position perturbations to each dimension
+
+ rotation_perturbation_size (float): magnitude of camera rotation perturbations in axis-angle.
+ Default corresponds to around 5 degrees.
+
+ fovy_perturbation_size (float): magnitude of camera fovy perturbations (corresponds to focusing)
+
+ Raises:
+ AssertionError: [No randomization selected]
+ """
+
+ def __init__(
+ self,
+ sim,
+ random_state=None,
+ camera_names=None,
+ randomize_position=True,
+ randomize_rotation=True,
+ randomize_fovy=True,
+ position_perturbation_size=0.01,
+ rotation_perturbation_size=0.087,
+ fovy_perturbation_size=5.0,
+ ):
+ super().__init__(sim, random_state=random_state)
+
+ assert randomize_position or randomize_rotation or randomize_fovy
+
+ if camera_names is None:
+ camera_names = self.sim.model.camera_names
+ self.camera_names = camera_names
+
+ self.randomize_position = randomize_position
+ self.randomize_rotation = randomize_rotation
+ self.randomize_fovy = randomize_fovy
+
+ self.position_perturbation_size = position_perturbation_size
+ self.rotation_perturbation_size = rotation_perturbation_size
+ self.fovy_perturbation_size = fovy_perturbation_size
+
+ self.save_defaults()
+
+ def save_defaults(self):
+ """
+ Uses the current MjSim state and model to save default parameter values.
+ """
+ self._defaults = {k: {} for k in self.camera_names}
+ for camera_name in self.camera_names:
+ self._defaults[camera_name]["pos"] = np.array(self.get_pos(camera_name))
+ self._defaults[camera_name]["quat"] = np.array(self.get_quat(camera_name))
+ self._defaults[camera_name]["fovy"] = self.get_fovy(camera_name)
+
+ def restore_defaults(self):
+ """
+ Reloads the saved parameter values.
+ """
+ for camera_name in self.camera_names:
+ self.set_pos(camera_name, self._defaults[camera_name]["pos"])
+ self.set_quat(camera_name, self._defaults[camera_name]["quat"])
+ self.set_fovy(camera_name, self._defaults[camera_name]["fovy"])
+
+ def randomize(self):
+ """
+ Randomizes all requested camera values within the sim
+ """
+ for camera_name in self.camera_names:
+ if self.randomize_position:
+ self._randomize_position(camera_name)
+
+ if self.randomize_rotation:
+ self._randomize_rotation(camera_name)
+
+ if self.randomize_fovy:
+ self._randomize_fovy(camera_name)
+
+ def _randomize_position(self, name):
+ """
+ Helper function to randomize position of a specific camera
+
+ Args:
+ name (str): Name of the camera to randomize for
+ """
+ delta_pos = self.random_state.uniform(
+ low=-self.position_perturbation_size,
+ high=self.position_perturbation_size,
+ size=3,
+ )
+ self.set_pos(
+ name,
+ self._defaults[name]["pos"] + delta_pos,
+ )
+
+ def _randomize_rotation(self, name):
+ """
+ Helper function to randomize orientation of a specific camera
+
+ Args:
+ name (str): Name of the camera to randomize for
+ """
+ # sample a small, random axis-angle delta rotation
+ random_axis, random_angle = trans.random_axis_angle(
+ angle_limit=self.rotation_perturbation_size, random_state=self.random_state
+ )
+ random_delta_rot = trans.quat2mat(trans.axisangle2quat(random_axis * random_angle))
+
+ # compute new rotation and set it
+ base_rot = trans.quat2mat(trans.convert_quat(self._defaults[name]["quat"], to="xyzw"))
+ new_rot = random_delta_rot.T.dot(base_rot)
+ new_quat = trans.convert_quat(trans.mat2quat(new_rot), to="wxyz")
+ self.set_quat(
+ name,
+ new_quat,
+ )
+
+ def _randomize_fovy(self, name):
+ """
+ Helper function to randomize fovy of a specific camera
+
+ Args:
+ name (str): Name of the camera to randomize for
+ """
+ delta_fovy = self.random_state.uniform(
+ low=-self.fovy_perturbation_size,
+ high=self.fovy_perturbation_size,
+ )
+ self.set_fovy(
+ name,
+ self._defaults[name]["fovy"] + delta_fovy,
+ )
+
+ def get_fovy(self, name):
+ """
+ Grabs fovy of a specific camera
+
+ Args:
+ name (str): Name of the camera
+
+ Returns:
+ float: vertical field of view of the camera, expressed in degrees
+
+ Raises:
+ AssertionError: Invalid camera name
+ """
+ camid = self.get_camid(name)
+ assert camid > -1, "Unknown camera %s" % name
+ return self.model.cam_fovy[camid]
+
+ def set_fovy(self, name, value):
+ """
+ Sets fovy of a specific camera
+
+ Args:
+ name (str): Name of the camera
+ value (float): vertical field of view of the camera, expressed in degrees
+
+ Raises:
+ AssertionError: Invalid camera name
+ AssertionError: Invalid value
+ """
+ camid = self.get_camid(name)
+ assert 0 < value < 180
+ assert camid > -1, "Unknown camera %s" % name
+ self.model.cam_fovy[camid] = value
+
+ def get_quat(self, name):
+ """
+ Grabs orientation of a specific camera
+
+ Args:
+ name (str): Name of the camera
+
+ Returns:
+ np.array: (w,x,y,z) orientation of the camera, expressed in quaternions
+
+ Raises:
+ AssertionError: Invalid camera name
+ """
+ camid = self.get_camid(name)
+ assert camid > -1, "Unknown camera %s" % name
+ return self.model.cam_quat[camid]
+
+ def set_quat(self, name, value):
+ """
+ Sets orientation of a specific camera
+
+ Args:
+ name (str): Name of the camera
+ value (np.array): (w,x,y,z) orientation of the camera, expressed in quaternions
+
+ Raises:
+ AssertionError: Invalid camera name
+ AssertionError: Invalid value
+ """
+ value = list(value)
+ assert len(value) == 4, "Expectd value of length 4, instead got %s" % value
+ camid = self.get_camid(name)
+ assert camid > -1, "Unknown camera %s" % name
+ self.model.cam_quat[camid] = value
+
+ def get_pos(self, name):
+ """
+ Grabs position of a specific camera
+
+ Args:
+ name (str): Name of the camera
+
+ Returns:
+ np.array: (x,y,z) position of the camera
+
+ Raises:
+ AssertionError: Invalid camera name
+ """
+ camid = self.get_camid(name)
+ assert camid > -1, "Unknown camera %s" % name
+ return self.model.cam_pos[camid]
+
+ def set_pos(self, name, value):
+ """
+ Sets position of a specific camera
+
+ Args:
+ name (str): Name of the camera
+ value (np.array): (x,y,z) position of the camera
+
+ Raises:
+ AssertionError: Invalid camera name
+ AssertionError: Invalid value
+ """
+ value = list(value)
+ assert len(value) == 3, "Expected value of length 3, instead got %s" % value
+ camid = self.get_camid(name)
+ assert camid > -1
+ self.model.cam_pos[camid] = value
+
+ def get_camid(self, name):
+ """
+ Grabs unique id number of a specific camera
+
+ Args:
+ name (str): Name of the camera
+
+ Returns:
+ int: id of camera. -1 if not found
+ """
+ return self.model.camera_name2id(name)
+
+
+class TextureModder(BaseModder):
+ """
+ Modify textures in model. Example use:
+ sim = MjSim(...)
+ modder = TextureModder(sim)
+ modder.whiten_materials() # ensures materials won't impact colors
+ modder.set_checker('some_geom', (255, 0, 0), (0, 0, 0))
+ modder.rand_all('another_geom')
+
+ Note: in order for the textures to take full effect, you'll need to set
+ the rgba values for all materials to [1, 1, 1, 1], otherwise the texture
+ colors will be modulated by the material colors. Call the
+ `whiten_materials` helper method to set all material colors to white.
+
+ Args:
+ sim (MjSim): MjSim object
+
+ random_state (RandomState): instance of np.random.RandomState
+
+ geom_names ([string]): list of geom names to use for randomization. If not provided,
+ all geoms are used for randomization.
+
+ randomize_local (bool): if True, constrain RGB color variations to be close to the
+ original RGB colors per geom and texture. Otherwise, RGB color values will
+ be sampled uniformly at random.
+
+ randomize_material (bool): if True, randomizes material properties associated with a
+ given texture (reflectance, shininess, specular)
+
+ local_rgb_interpolation (float): determines the size of color variations from
+ the base geom colors when @randomize_local is True.
+
+ local_material_interpolation (float): determines the size of material variations from
+ the base material when @randomize_local and @randomize_material are both True.
+
+ texture_variations (list of str): a list of texture variation strings. Each string
+ must be either 'rgb', 'checker', 'noise', or 'gradient' and corresponds to
+ a specific kind of texture randomization. For each geom that has a material
+ and texture, a random variation from this list is sampled and applied.
+
+ randomize_skybox (bool): if True, apply texture variations to the skybox as well.
+ """
+
+ def __init__(
+ self,
+ sim,
+ random_state=None,
+ geom_names=None,
+ randomize_local=False,
+ randomize_material=False,
+ local_rgb_interpolation=0.1,
+ local_material_interpolation=0.2,
+ texture_variations=("rgb", "checker", "noise", "gradient"),
+ randomize_skybox=True,
+ ):
+ super().__init__(sim, random_state=random_state)
+
+ if geom_names is None:
+ geom_names = self.sim.model.geom_names
+ self.geom_names = geom_names
+
+ self.randomize_local = randomize_local
+ self.randomize_material = randomize_material
+ self.local_rgb_interpolation = local_rgb_interpolation
+ self.local_material_interpolation = local_material_interpolation
+ self.texture_variations = list(texture_variations)
+ self.randomize_skybox = randomize_skybox
+
+ self._all_texture_variation_callbacks = {
+ "rgb": self.rand_rgb,
+ "checker": self.rand_checker,
+ "noise": self.rand_noise,
+ "gradient": self.rand_gradient,
+ }
+ self._texture_variation_callbacks = {
+ k: self._all_texture_variation_callbacks[k] for k in self.texture_variations
+ }
+
+ self.save_defaults()
+
+ def save_defaults(self):
+ """
+ Uses the current MjSim state and model to save default parameter values.
+ """
+ self.textures = [Texture(self.model, i) for i in range(self.model.ntex)]
+ # self._build_tex_geom_map()
+
+ # save copy of original texture bitmaps
+ self._default_texture_bitmaps = [np.array(text.bitmap) for text in self.textures]
+
+ # These matrices will be used to rapidly synthesize
+ # checker pattern bitmaps
+ self._cache_checker_matrices()
+
+ self._defaults = {k: {} for k in self.geom_names}
+ if self.randomize_skybox:
+ self._defaults["skybox"] = {}
+ for name in self.geom_names:
+ if self._check_geom_for_texture(name):
+ # store the texture bitmap for this geom
+ tex_id = self._name_to_tex_id(name)
+ self._defaults[name]["texture"] = self._default_texture_bitmaps[tex_id]
+ # store material properties as well (in tuple (reflectance, shininess, specular) form)
+ self._defaults[name]["material"] = self.get_material(name)
+ else:
+ # store geom color
+ self._defaults[name]["rgb"] = np.array(self.get_geom_rgb(name))
+
+ if self.randomize_skybox:
+ tex_id = self._name_to_tex_id("skybox")
+ self._defaults["skybox"]["texture"] = self._default_texture_bitmaps[tex_id]
+
+ def restore_defaults(self):
+ """
+ Reloads the saved parameter values.
+ """
+ for name in self.geom_names:
+ if self._check_geom_for_texture(name):
+ self.set_texture(name, self._defaults[name]["texture"], perturb=False)
+ self.set_material(name, self._defaults[name]["material"], perturb=False)
+ else:
+ self.set_geom_rgb(name, self._defaults[name]["rgb"])
+
+ if self.randomize_skybox:
+ self.set_texture("skybox", self._defaults["skybox"]["texture"], perturb=False)
+
+ def randomize(self):
+ """
+ Overrides mujoco-py implementation to also randomize color
+ for geoms that have no material.
+ """
+ self.whiten_materials()
+ for name in self.geom_names:
+ if self._check_geom_for_texture(name):
+ # geom has valid texture that can be randomized
+ self._randomize_texture(name)
+ # randomize material if requested
+ if self.randomize_material:
+ self._randomize_material(name)
+ else:
+ # randomize geom color
+ self._randomize_geom_color(name)
+
+ if self.randomize_skybox:
+ self._randomize_texture("skybox")
+
+ def _randomize_geom_color(self, name):
+ """
+ Helper function to randomize color of a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ if self.randomize_local:
+ random_color = self.random_state.uniform(0, 1, size=3)
+ rgb = (1.0 - self.local_rgb_interpolation) * self._defaults[name][
+ "rgb"
+ ] + self.local_rgb_interpolation * random_color
+ else:
+ rgb = self.random_state.uniform(0, 1, size=3)
+ self.set_geom_rgb(name, rgb)
+
+ def _randomize_texture(self, name):
+ """
+ Helper function to randomize texture of a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ keys = list(self._texture_variation_callbacks.keys())
+ choice = keys[self.random_state.randint(len(keys))]
+ self._texture_variation_callbacks[choice](name)
+
+ def _randomize_material(self, name):
+ """
+ Helper function to randomize material of a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ # Return immediately if this is the skybox
+ if name == "skybox":
+ return
+ # Grab material id
+ mat_id = self._name_to_mat_id(name)
+ # Randomize reflectance, shininess, and specular
+ material = self.random_state.uniform(0, 1, size=3) # (reflectance, shininess, specular)
+ self.set_material(name, material, perturb=self.randomize_local)
+
+ def rand_checker(self, name):
+ """
+ Generates a random checker pattern for a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ rgb1, rgb2 = self.get_rand_rgb(2)
+ self.set_checker(name, rgb1, rgb2, perturb=self.randomize_local)
+
+ def rand_gradient(self, name):
+ """
+ Generates a random gradient pattern for a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ rgb1, rgb2 = self.get_rand_rgb(2)
+ vertical = bool(self.random_state.uniform() > 0.5)
+ self.set_gradient(name, rgb1, rgb2, vertical=vertical, perturb=self.randomize_local)
+
+ def rand_rgb(self, name):
+ """
+ Generates a random RGB color for a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ rgb = self.get_rand_rgb()
+ self.set_rgb(name, rgb, perturb=self.randomize_local)
+
+ def rand_noise(self, name):
+ """
+ Generates a random RGB noise pattern for a specific geom
+
+ Args:
+ name (str): Name of the geom to randomize for
+ """
+ fraction = 0.1 + self.random_state.uniform() * 0.8
+ rgb1, rgb2 = self.get_rand_rgb(2)
+ self.set_noise(name, rgb1, rgb2, fraction, perturb=self.randomize_local)
+
+ def whiten_materials(self):
+ """
+ Extends modder.TextureModder to also whiten geom_rgba
+
+ Helper method for setting all material colors to white, otherwise
+ the texture modifications won't take full effect.
+ """
+ for name in self.geom_names:
+ # whiten geom
+ geom_id = self.model.geom_name2id(name)
+ self.model.geom_rgba[geom_id, :] = 1.0
+
+ if self._check_geom_for_texture(name):
+ # whiten material
+ mat_id = self.model.geom_matid[geom_id]
+ self.model.mat_rgba[mat_id, :] = 1.0
+
+ def get_geom_rgb(self, name):
+ """
+ Grabs rgb color of a specific geom
+
+ Args:
+ name (str): Name of the geom
+
+ Returns:
+ np.array: (r,g,b) geom colors
+ """
+ geom_id = self.model.geom_name2id(name)
+ return self.model.geom_rgba[geom_id, :3]
+
+ def set_geom_rgb(self, name, rgb):
+ """
+ Sets rgb color of a specific geom
+
+ Args:
+ name (str): Name of the geom
+ rgb (np.array): (r,g,b) geom colors
+ """
+ geom_id = self.model.geom_name2id(name)
+ self.model.geom_rgba[geom_id, :3] = rgb
+
+ def get_rand_rgb(self, n=1):
+ """
+ Grabs a batch of random rgb tuple combos
+
+ Args:
+ n (int): How many sets of rgb tuples to randomly generate
+
+ Returns:
+ np.array or n-tuple: if n > 1, each tuple entry is a rgb tuple. else, single (r,g,b) array
+ """
+
+ def _rand_rgb():
+ return np.array(self.random_state.uniform(size=3) * 255, dtype=np.uint8)
+
+ if n == 1:
+ return _rand_rgb()
+ else:
+ return tuple(_rand_rgb() for _ in range(n))
+
+ def get_texture(self, name):
+ """
+ Grabs texture of a specific geom
+
+ Args:
+ name (str): Name of the geom
+
+ Returns:
+ Texture: texture associated with the geom
+ """
+ tex_id = self._name_to_tex_id(name)
+ texture = self.textures[tex_id]
+ return texture
+
+ def set_texture(self, name, bitmap, perturb=False):
+ """
+ Sets the bitmap for the texture that corresponds
+ to geom @name.
+
+ If @perturb is True, then use the computed bitmap
+ to perturb the default bitmap slightly, instead
+ of replacing it.
+
+ Args:
+ name (str): Name of the geom
+ bitmap (np.array): 3d-array representing rgb pixel-wise values
+ perturb (bool): Whether to perturb the inputted bitmap or not
+ """
+ bitmap_to_set = self.get_texture(name).bitmap
+ if perturb:
+ bitmap = (1.0 - self.local_rgb_interpolation) * self._defaults[name][
+ "texture"
+ ] + self.local_rgb_interpolation * bitmap
+ bitmap_to_set[:] = bitmap
+ self.upload_texture(name)
+
+ def get_material(self, name):
+ """
+ Grabs material of a specific geom
+
+ Args:
+ name (str): Name of the geom
+
+ Returns:
+ np.array: (reflectance, shininess, specular) material properties associated with the geom
+ """
+ mat_id = self._name_to_mat_id(name)
+ # Material is in tuple form (reflectance, shininess, specular)
+ material = np.array(
+ (self.model.mat_reflectance[mat_id], self.model.mat_shininess[mat_id], self.model.mat_specular[mat_id])
+ )
+ return material
+
+ def set_material(self, name, material, perturb=False):
+ """
+ Sets the material that corresponds to geom @name.
+
+ If @perturb is True, then use the computed material
+ to perturb the default material slightly, instead
+ of replacing it.
+
+ Args:
+ name (str): Name of the geom
+ material (np.array): (reflectance, shininess, specular) material properties associated with the geom
+ perturb (bool): Whether to perturb the inputted material properties or not
+ """
+ mat_id = self._name_to_mat_id(name)
+ if perturb:
+ material = (1.0 - self.local_material_interpolation) * self._defaults[name][
+ "material"
+ ] + self.local_material_interpolation * material
+ self.model.mat_reflectance[mat_id] = material[0]
+ self.model.mat_shininess[mat_id] = material[1]
+ self.model.mat_specular[mat_id] = material[2]
+
+ def get_checker_matrices(self, name):
+ """
+ Grabs checker pattern matrix associated with @name.
+
+ Args:
+ name (str): Name of geom
+
+ Returns:
+ np.array: 3d-array representing rgb checker pattern
+ """
+ tex_id = self._name_to_tex_id(name)
+ return self._texture_checker_mats[tex_id]
+
+ def set_checker(self, name, rgb1, rgb2, perturb=False):
+ """
+ Use the two checker matrices to create a checker
+ pattern from the two colors, and set it as
+ the texture for geom @name.
+
+ Args:
+ name (str): Name of geom
+ rgb1 (3-array): (r,g,b) value for one half of checker pattern
+ rgb2 (3-array): (r,g,b) value for other half of checker pattern
+ perturb (bool): Whether to perturb the resulting checker pattern or not
+ """
+ cbd1, cbd2 = self.get_checker_matrices(name)
+ rgb1 = np.asarray(rgb1).reshape([1, 1, -1])
+ rgb2 = np.asarray(rgb2).reshape([1, 1, -1])
+ bitmap = rgb1 * cbd1 + rgb2 * cbd2
+
+ self.set_texture(name, bitmap, perturb=perturb)
+
+ def set_gradient(self, name, rgb1, rgb2, vertical=True, perturb=False):
+ """
+ Creates a linear gradient from rgb1 to rgb2.
+
+ Args:
+ name (str): Name of geom
+ rgb1 (3-array): start color
+ rgb2 (3- array): end color
+ vertical (bool): if True, the gradient in the positive
+ y-direction, if False it's in the positive x-direction.
+ perturb (bool): Whether to perturb the resulting gradient pattern or not
+ """
+ # NOTE: MuJoCo's gradient uses a sigmoid. Here we simplify
+ # and just use a linear gradient... We could change this
+ # to just use a tanh-sigmoid if needed.
+ bitmap = self.get_texture(name).bitmap
+ h, w = bitmap.shape[:2]
+ if vertical:
+ p = np.tile(np.linspace(0, 1, h)[:, None], (1, w))
+ else:
+ p = np.tile(np.linspace(0, 1, w), (h, 1))
+
+ new_bitmap = np.zeros_like(bitmap)
+ for i in range(3):
+ new_bitmap[..., i] = rgb2[i] * p + rgb1[i] * (1.0 - p)
+
+ self.set_texture(name, new_bitmap, perturb=perturb)
+
+ def set_rgb(self, name, rgb, perturb=False):
+ """
+ Just set the texture bitmap for geom @name
+ to a constant rgb value.
+
+ Args:
+ name (str): Name of geom
+ rgb (3-array): desired (r,g,b) color
+ perturb (bool): Whether to perturb the resulting color pattern or not
+ """
+ bitmap = self.get_texture(name).bitmap
+ new_bitmap = np.zeros_like(bitmap)
+ new_bitmap[..., :] = np.asarray(rgb)
+
+ self.set_texture(name, new_bitmap, perturb=perturb)
+
+ def set_noise(self, name, rgb1, rgb2, fraction=0.9, perturb=False):
+ """
+ Sets the texture bitmap for geom @name to a noise pattern
+
+ Args:
+ name (str): name of geom
+ rgb1 (3-array): background color
+ rgb2 (3-array): color of random noise foreground color
+ fraction (float): fraction of pixels with foreground color
+ perturb (bool): Whether to perturb the resulting color pattern or not
+ """
+ bitmap = self.get_texture(name).bitmap
+ h, w = bitmap.shape[:2]
+ mask = self.random_state.uniform(size=(h, w)) < fraction
+
+ new_bitmap = np.zeros_like(bitmap)
+ new_bitmap[..., :] = np.asarray(rgb1)
+ new_bitmap[mask, :] = np.asarray(rgb2)
+
+ self.set_texture(name, new_bitmap, perturb=perturb)
+
+ def upload_texture(self, name, device_id=0):
+ """
+ Uploads the texture to the GPU so it's available in the rendering.
+
+ Args:
+ name (str): name of geom
+ """
+ texture = self.get_texture(name)
+ if self.sim._render_context_offscreen is None:
+ render_context = MjRenderContextOffscreen(self.sim, device_id)
+ render_context.upload_texture(texture.id)
+
+ def _check_geom_for_texture(self, name):
+ """
+ Helper function to determined if the geom @name has
+ an assigned material and that the material has
+ an assigned texture.
+
+ Args:
+ name (str): name of geom
+
+ Returns:
+ bool: True if specific geom has both material and texture associated, else False
+ """
+ geom_id = self.model.geom_name2id(name)
+ mat_id = self.model.geom_matid[geom_id]
+ if mat_id < 0:
+ return False
+ tex_id = self.model.mat_texid[mat_id]
+ if tex_id < 0:
+ return False
+ return True
+
+ def _name_to_tex_id(self, name):
+ """
+ Helper function to get texture id from geom name.
+
+ Args:
+ name (str): name of geom
+
+ Returns:
+ int: id of texture associated with geom
+
+ Raises:
+ AssertionError: [No texture associated with geom]
+ """
+
+ # handle skybox separately
+ if name == "skybox":
+ skybox_tex_id = -1
+ for tex_id in range(self.model.ntex):
+ skybox_textype = 2
+ if self.model.tex_type[tex_id] == skybox_textype:
+ skybox_tex_id = tex_id
+ assert skybox_tex_id >= 0
+ return skybox_tex_id
+
+ assert self._check_geom_for_texture(name)
+ geom_id = self.model.geom_name2id(name)
+ mat_id = self.model.geom_matid[geom_id]
+ tex_id = self.model.mat_texid[mat_id]
+ return tex_id
+
+ def _name_to_mat_id(self, name):
+ """
+ Helper function to get material id from geom name.
+
+ Args:
+ name (str): name of geom
+
+ Returns:
+ int: id of material associated with geom
+
+ Raises:
+ ValueError: [No material associated with skybox]
+ AssertionError: [No material associated with geom]
+ """
+
+ # handle skybox separately
+ if name == "skybox":
+ raise ValueError("Error: skybox has no material!")
+
+ assert self._check_geom_for_texture(name)
+ geom_id = self.model.geom_name2id(name)
+ mat_id = self.model.geom_matid[geom_id]
+ return mat_id
+
+ def _cache_checker_matrices(self):
+ """
+ Cache two matrices of the form [[1, 0, 1, ...],
+ [0, 1, 0, ...],
+ ...]
+ and [[0, 1, 0, ...],
+ [1, 0, 1, ...],
+ ...]
+ for each texture. To use for fast creation of checkerboard patterns
+ """
+ self._texture_checker_mats = []
+ for tex_id in range(self.model.ntex):
+ texture = self.textures[tex_id]
+ h, w = texture.bitmap.shape[:2]
+ self._texture_checker_mats.append(self._make_checker_matrices(h, w))
+
+ def _make_checker_matrices(self, h, w):
+ """
+ Helper function to quickly generate binary matrices used to create checker patterns
+
+ Args:
+ h (int): Desired height of matrices
+ w (int): Desired width of matrices
+
+ Returns:
+ 2-tuple:
+
+ - (np.array): 2d-array representing first half of checker matrix
+ - (np.array): 2d-array representing second half of checker matrix
+ """
+ re = np.r_[((w + 1) // 2) * [0, 1]]
+ ro = np.r_[((w + 1) // 2) * [1, 0]]
+ cbd1 = np.expand_dims(np.row_stack(((h + 1) // 2) * [re, ro]), -1)[:h, :w]
+ cbd2 = np.expand_dims(np.row_stack(((h + 1) // 2) * [ro, re]), -1)[:h, :w]
+ return cbd1, cbd2
+
+
+# From mjtTexture
+MJT_TEXTURE_ENUM = ["2d", "cube", "skybox"]
+
+
+class Texture:
+ """
+ Helper class for operating on the MuJoCo textures.
+
+ Args:
+ model (MjModel): Mujoco sim model
+ tex_id (int): id of specific texture in mujoco sim
+ """
+
+ __slots__ = ["id", "type", "height", "width", "tex_adr", "tex_rgb"]
+
+ def __init__(self, model, tex_id):
+ self.id = tex_id
+ self.type = MJT_TEXTURE_ENUM[model.tex_type[tex_id]]
+ self.height = model.tex_height[tex_id]
+ self.width = model.tex_width[tex_id]
+ self.tex_adr = model.tex_adr[tex_id]
+ self.tex_rgb = model.tex_rgb
+
+ @property
+ def bitmap(self):
+ """
+ Grabs color bitmap associated with this texture from the mujoco sim.
+
+ Returns:
+ np.array: 3d-array representing the rgb texture bitmap
+ """
+ size = self.height * self.width * 3
+ data = self.tex_rgb[self.tex_adr : self.tex_adr + size]
+ return data.reshape((self.height, self.width, 3))
+
+
+class DynamicsModder(BaseModder):
+ """
+ Modder for various dynamics properties of the mujoco model, such as friction, damping, etc.
+ This can be used to modify parameters stored in MjModel (ie friction, damping, etc.) as
+ well as optimizer parameters stored in PyMjOption (i.e.: medium density, viscosity, etc.)
+ To modify a parameter, use the parameter to be changed as a keyword argument to
+ self.mod and the new value as the value for that argument. Supports arbitrary many
+ modifications in a single step. Example use:
+ sim = MjSim(...)
+ modder = DynamicsModder(sim)
+ modder.mod("element1_name", "attr1", new_value1)
+ modder.mod("element2_name", "attr2", new_value2)
+ ...
+ modder.update()
+
+ NOTE: It is necessary to perform modder.update() after performing all modifications to make sure
+ the changes are propagated
+
+ NOTE: A full list of supported randomizable parameters can be seen by calling modder.dynamics_parameters
+
+ NOTE: When modifying parameters belonging to MjModel.opt (e.g.: density, viscosity), no name should
+ be specified (set it as None in mod(...)). This is because opt does not have a name attribute
+ associated with it
+
+ Args:
+ sim (MjSim): Mujoco sim instance
+
+ random_state (RandomState): instance of np.random.RandomState
+
+ randomize_density (bool): If True, randomizes global medium density
+
+ randomize_viscosity (bool): If True, randomizes global medium viscosity
+
+ density_perturbation_ratio (float): Relative (fraction) magnitude of default density randomization
+
+ viscosity_perturbation_ratio: Relative (fraction) magnitude of default viscosity randomization
+
+ body_names (None or list of str): list of bodies to use for randomization. If not provided, all
+ bodies in the model are randomized.
+
+ randomize_position (bool): If True, randomizes body positions
+
+ randomize_quaternion (bool): If True, randomizes body quaternions
+
+ randomize_inertia (bool): If True, randomizes body inertias (only applicable for non-zero mass bodies)
+
+ randomize_mass (bool): If True, randomizes body masses (only applicable for non-zero mass bodies)
+
+ position_perturbation_size (float): Magnitude of body position randomization
+
+ quaternion_perturbation_size (float): Magnitude of body quaternion randomization (angle in radians)
+
+ inertia_perturbation_ratio (float): Relative (fraction) magnitude of body inertia randomization
+
+ mass_perturbation_ratio (float): Relative (fraction) magnitude of body mass randomization
+
+ geom_names (None or list of str): list of geoms to use for randomization. If not provided, all
+ geoms in the model are randomized.
+
+ randomize_friction (bool): If True, randomizes geom frictions
+
+ randomize_solref (bool): If True, randomizes geom solrefs
+
+ randomize_solimp (bool): If True, randomizes geom solimps
+
+ friction_perturbation_ratio (float): Relative (fraction) magnitude of geom friction randomization
+
+ solref_perturbation_ratio (float): Relative (fraction) magnitude of geom solref randomization
+
+ solimp_perturbation_ratio (float): Relative (fraction) magnitude of geom solimp randomization
+
+ joint_names (None or list of str): list of joints to use for randomization. If not provided, all
+ joints in the model are randomized.
+
+ randomize_stiffness (bool): If True, randomizes joint stiffnesses
+
+ randomize_frictionloss (bool): If True, randomizes joint frictionlosses
+
+ randomize_damping (bool): If True, randomizes joint dampings
+
+ randomize_armature (bool): If True, randomizes joint armatures
+
+ stiffness_perturbation_ratio (float): Relative (fraction) magnitude of joint stiffness randomization
+
+ frictionloss_perturbation_size (float): Magnitude of joint frictionloss randomization
+
+ damping_perturbation_size (float): Magnitude of joint damping randomization
+
+ armature_perturbation_size (float): Magnitude of joint armature randomization
+ """
+
+ def __init__(
+ self,
+ sim,
+ random_state=None,
+ # Opt parameters
+ randomize_density=True,
+ randomize_viscosity=True,
+ density_perturbation_ratio=0.1,
+ viscosity_perturbation_ratio=0.1,
+ # Body parameters
+ body_names=None,
+ randomize_position=True,
+ randomize_quaternion=True,
+ randomize_inertia=True,
+ randomize_mass=True,
+ position_perturbation_size=0.02,
+ quaternion_perturbation_size=0.02,
+ inertia_perturbation_ratio=0.02,
+ mass_perturbation_ratio=0.02,
+ # Geom parameters
+ geom_names=None,
+ randomize_friction=True,
+ randomize_solref=True,
+ randomize_solimp=True,
+ friction_perturbation_ratio=0.1,
+ solref_perturbation_ratio=0.1,
+ solimp_perturbation_ratio=0.1,
+ # Joint parameters
+ joint_names=None,
+ randomize_stiffness=True,
+ randomize_frictionloss=True,
+ randomize_damping=True,
+ randomize_armature=True,
+ stiffness_perturbation_ratio=0.1,
+ frictionloss_perturbation_size=0.05,
+ damping_perturbation_size=0.01,
+ armature_perturbation_size=0.01,
+ ):
+ super().__init__(sim=sim, random_state=random_state)
+
+ # Setup relevant values
+ self.dummy_bodies = set()
+ # Find all bodies that don't have any mass associated with them
+ for body_name in self.sim.model.body_names:
+ body_id = self.sim.model.body_name2id(body_name)
+ if self.sim.model.body_mass[body_id] == 0:
+ self.dummy_bodies.add(body_name)
+
+ # Get all values to randomize
+ self.body_names = list(self.sim.model.body_names) if body_names is None else body_names
+ self.geom_names = list(self.sim.model.geom_names) if geom_names is None else geom_names
+ self.joint_names = list(self.sim.model.joint_names) if joint_names is None else joint_names
+
+ # Setup randomization settings
+ # Each dynamics randomization group has its set of randomizable parameters, each of which has
+ # its own settings ["randomize": whether its actively being randomized, "perturbation": the (potentially)
+ # relative magnitude of the randomization to use, "type": either "ratio" or "size" (relative or absolute
+ # perturbations), and "clip": (low, high) values to clip the final perturbed value by]
+ self.opt_randomizations = {
+ "density": {
+ "randomize": randomize_density,
+ "perturbation": density_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ "viscosity": {
+ "randomize": randomize_viscosity,
+ "perturbation": viscosity_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ }
+
+ self.body_randomizations = {
+ "position": {
+ "randomize": randomize_position,
+ "perturbation": position_perturbation_size,
+ "type": "size",
+ "clip": (-np.inf, np.inf),
+ },
+ "quaternion": {
+ "randomize": randomize_quaternion,
+ "perturbation": quaternion_perturbation_size,
+ "type": "size",
+ "clip": (-np.inf, np.inf),
+ },
+ "inertia": {
+ "randomize": randomize_inertia,
+ "perturbation": inertia_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ "mass": {
+ "randomize": randomize_mass,
+ "perturbation": mass_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ }
+
+ self.geom_randomizations = {
+ "friction": {
+ "randomize": randomize_friction,
+ "perturbation": friction_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ "solref": {
+ "randomize": randomize_solref,
+ "perturbation": solref_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, 1.0),
+ },
+ "solimp": {
+ "randomize": randomize_solimp,
+ "perturbation": solimp_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ }
+
+ self.joint_randomizations = {
+ "stiffness": {
+ "randomize": randomize_stiffness,
+ "perturbation": stiffness_perturbation_ratio,
+ "type": "ratio",
+ "clip": (0.0, np.inf),
+ },
+ "frictionloss": {
+ "randomize": randomize_frictionloss,
+ "perturbation": frictionloss_perturbation_size,
+ "type": "size",
+ "clip": (0.0, np.inf),
+ },
+ "damping": {
+ "randomize": randomize_damping,
+ "perturbation": damping_perturbation_size,
+ "type": "size",
+ "clip": (0.0, np.inf),
+ },
+ "armature": {
+ "randomize": randomize_armature,
+ "perturbation": armature_perturbation_size,
+ "type": "size",
+ "clip": (0.0, np.inf),
+ },
+ }
+
+ # Store defaults so we don't loss track of the original (non-perturbed) values
+ self.opt_defaults = None
+ self.body_defaults = None
+ self.geom_defaults = None
+ self.joint_defaults = None
+ self.save_defaults()
+
+ def save_defaults(self):
+ """
+ Grabs the current values for all parameters in sim and stores them as default values
+ """
+ self.opt_defaults = {
+ None: { # no name associated with the opt parameters
+ "density": self.sim.model.opt.density,
+ "viscosity": self.sim.model.opt.viscosity,
+ }
+ }
+
+ self.body_defaults = {}
+ for body_name in self.sim.model.body_names:
+ body_id = self.sim.model.body_name2id(body_name)
+ self.body_defaults[body_name] = {
+ "position": np.array(self.sim.model.body_pos[body_id]),
+ "quaternion": np.array(self.sim.model.body_quat[body_id]),
+ "inertia": np.array(self.sim.model.body_inertia[body_id]),
+ "mass": self.sim.model.body_mass[body_id],
+ }
+
+ self.geom_defaults = {}
+ for geom_name in self.sim.model.geom_names:
+ geom_id = self.sim.model.geom_name2id(geom_name)
+ self.geom_defaults[geom_name] = {
+ "friction": np.array(self.sim.model.geom_friction[geom_id]),
+ "solref": np.array(self.sim.model.geom_solref[geom_id]),
+ "solimp": np.array(self.sim.model.geom_solimp[geom_id]),
+ }
+
+ self.joint_defaults = {}
+ for joint_name in self.sim.model.joint_names:
+ joint_id = self.sim.model.joint_name2id(joint_name)
+ dof_idx = [i for i, v in enumerate(self.sim.model.dof_jntid) if v == joint_id]
+ self.joint_defaults[joint_name] = {
+ "stiffness": self.sim.model.jnt_stiffness[joint_id],
+ "frictionloss": np.array(self.sim.model.dof_frictionloss[dof_idx]),
+ "damping": np.array(self.sim.model.dof_damping[dof_idx]),
+ "armature": np.array(self.sim.model.dof_armature[dof_idx]),
+ }
+
+ def restore_defaults(self):
+ """
+ Restores the default values curently saved in this modder
+ """
+ # Loop through all defaults and set the default value in sim
+ for group_defaults in (self.opt_defaults, self.body_defaults, self.geom_defaults, self.joint_defaults):
+ for name, defaults in group_defaults.items():
+ for attr, default_val in defaults.items():
+ self.mod(name=name, attr=attr, val=default_val)
+
+ # Make sure changes propagate in sim
+ self.update()
+
+ def randomize(self):
+ """
+ Randomizes all enabled dynamics parameters in the simulation
+ """
+ for group_defaults, group_randomizations, group_randomize_names in zip(
+ (self.opt_defaults, self.body_defaults, self.geom_defaults, self.joint_defaults),
+ (self.opt_randomizations, self.body_randomizations, self.geom_randomizations, self.joint_randomizations),
+ ([None], self.body_names, self.geom_names, self.joint_names),
+ ):
+ for name in group_randomize_names:
+ # Randomize all parameters associated with this element
+ for attr, default_val in group_defaults[name].items():
+ val = copy.copy(default_val)
+ settings = group_randomizations[attr]
+ if settings["randomize"]:
+ # Randomize accordingly, and clip the final perturbed value
+ perturbation = np.random.rand() if type(val) in {int, float} else np.random.rand(*val.shape)
+ perturbation = settings["perturbation"] * (-1 + 2 * perturbation)
+ val = val + perturbation if settings["type"] == "size" else val * (1.0 + perturbation)
+ val = np.clip(val, *settings["clip"])
+ # Modify this value
+ self.mod(name=name, attr=attr, val=val)
+
+ # Make sure changes propagate in sim
+ self.update()
+
+ def update_sim(self, sim):
+ """
+ In addition to super method, update internal default values to match the current values from
+ (the presumably new) @sim.
+
+ Args:
+ sim (MjSim): MjSim object
+ """
+ super().update_sim(sim=sim)
+ self.save_defaults()
+
+ def update(self):
+ """
+ Propagates the changes made up to this point through the simulation
+ """
+ self.sim.forward()
+
+ def mod(self, name, attr, val):
+ """
+ General method to modify dynamics parameter @attr to be new value @val, associated with element @name.
+
+ Args:
+ name (str): Name of element to modify parameter. This can be a body, geom, or joint name. If modifying
+ an opt parameter, this should be set to None
+ attr (str): Name of the dynamics parameter to modify. Valid options are self.dynamics_parameters
+ val (int or float or n-array): New value(s) to set for the given dynamics parameter. The type of this
+ argument should match the expected type for the given parameter.
+ """
+ # Make sure specified parameter is valid, and then modify it
+ assert (
+ attr in self.dynamics_parameters
+ ), "Invalid dynamics parameter specified! Supported parameters are: {};" " requested: {}".format(
+ self.dynamics_parameters, attr
+ )
+ # Modify the requested parameter (uses a clean way to programmatically call the appropriate method)
+ getattr(self, f"mod_{attr}")(name, val)
+
+ def mod_density(self, name=None, val=0.0):
+ """
+ Modifies the global medium density of the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#option for more details.
+
+ Args:
+ name (str): Name for this element. Should be left as None (opt has no name attribute)
+ val (float): New density value.
+ """
+ # Make sure inputs are of correct form
+ assert name is None, "No name should be specified if modding density!"
+
+ # Modify this value
+ self.sim.model.opt.density = val
+
+ def mod_viscosity(self, name=None, val=0.0):
+ """
+ Modifies the global medium viscosity of the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#option for more details.
+
+ Args:
+ name (str): Name for this element. Should be left as None (opt has no name attribute)
+ val (float): New viscosity value.
+ """
+ # Make sure inputs are of correct form
+ assert name is None, "No name should be specified if modding density!"
+
+ # Modify this value
+ self.sim.model.opt.viscosity = val
+
+ def mod_position(self, name, val=(0, 0, 0)):
+ """
+ Modifies the @name's relative body position within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#body for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (3-array): New (x, y, z) relative position.
+ """
+ # Modify this value
+ body_id = self.sim.model.body_name2id(name)
+ self.sim.model.body_pos[body_id] = np.array(val)
+
+ def mod_quaternion(self, name, val=(1, 0, 0, 0)):
+ """
+ Modifies the @name's relative body orientation (quaternion) within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#body for more details.
+
+ Note: This method automatically normalizes the inputted value.
+
+ Args:
+ name (str): Name for this element.
+ val (4-array): New (w, x, y, z) relative quaternion.
+ """
+ # Normalize the inputted value
+ val = np.array(val) / np.linalg.norm(val)
+ # Modify this value
+ body_id = self.sim.model.body_name2id(name)
+ self.sim.model.body_quat[body_id] = val
+
+ def mod_inertia(self, name, val):
+ """
+ Modifies the @name's relative body inertia within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#body for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (3-array): New (ixx, iyy, izz) diagonal values in the inertia matrix.
+ """
+ # Modify this value if it's not a dummy body
+ if name not in self.dummy_bodies:
+ body_id = self.sim.model.body_name2id(name)
+ self.sim.model.body_inertia[body_id] = np.array(val)
+
+ def mod_mass(self, name, val):
+ """
+ Modifies the @name's mass within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#body for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (float): New mass.
+ """
+ # Modify this value if it's not a dummy body
+ if name not in self.dummy_bodies:
+ body_id = self.sim.model.body_name2id(name)
+ self.sim.model.body_mass[body_id] = val
+
+ def mod_friction(self, name, val):
+ """
+ Modifies the @name's geom friction within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#geom for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (3-array): New (sliding, torsional, rolling) friction values.
+ """
+ # Modify this value
+ geom_id = self.sim.model.geom_name2id(name)
+ self.sim.model.geom_friction[geom_id] = np.array(val)
+
+ def mod_solref(self, name, val):
+ """
+ Modifies the @name's geom contact solver parameters within the simulation.
+ See http://www.mujoco.org/book/modeling.html#CSolver for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (2-array): New (timeconst, dampratio) solref values.
+ """
+ # Modify this value
+ geom_id = self.sim.model.geom_name2id(name)
+ self.sim.model.geom_solref[geom_id] = np.array(val)
+
+ def mod_solimp(self, name, val):
+ """
+ Modifies the @name's geom contact solver impedance parameters within the simulation.
+ See http://www.mujoco.org/book/modeling.html#CSolver for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (5-array): New (dmin, dmax, width, midpoint, power) solimp values.
+ """
+ # Modify this value
+ geom_id = self.sim.model.geom_name2id(name)
+ self.sim.model.geom_solimp[geom_id] = np.array(val)
+
+ def mod_stiffness(self, name, val):
+ """
+ Modifies the @name's joint stiffness within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#joint for more details.
+
+ NOTE: If the stiffness is already at 0, we IGNORE this value since a non-stiff joint (i.e.: free-turning)
+ joint is fundamentally different than a stiffened joint)
+
+ Args:
+ name (str): Name for this element.
+ val (float): New stiffness.
+ """
+ # Modify this value (only if there is stiffness to begin with)
+ jnt_id = self.sim.model.joint_name2id(name)
+ if self.sim.model.jnt_stiffness[jnt_id] != 0:
+ self.sim.model.jnt_stiffness[jnt_id] = val
+
+ def mod_frictionloss(self, name, val):
+ """
+ Modifies the @name's joint frictionloss within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#joint for more details.
+
+ NOTE: If the requested joint is a free joint, it will be ignored since it does not
+ make physical sense to have friction loss associated with this joint (air drag / damping
+ is already captured implicitly by the medium density / viscosity values)
+
+ Args:
+ name (str): Name for this element.
+ val (float): New friction loss.
+ """
+ # Modify this value (only if it's not a free joint)
+ jnt_id = self.sim.model.joint_name2id(name)
+ if self.sim.model.jnt_type[jnt_id] != 0:
+ dof_idx = [i for i, v in enumerate(self.sim.model.dof_jntid) if v == jnt_id]
+ self.sim.model.dof_frictionloss[dof_idx] = val
+
+ def mod_damping(self, name, val):
+ """
+ Modifies the @name's joint damping within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#joint for more details.
+
+ NOTE: If the requested joint is a free joint, it will be ignored since it does not
+ make physical sense to have damping associated with this joint (air drag / damping
+ is already captured implicitly by the medium density / viscosity values)
+
+ Args:
+ name (str): Name for this element.
+ val (float): New damping.
+ """
+ # Modify this value (only if it's not a free joint)
+ jnt_id = self.sim.model.joint_name2id(name)
+ if self.sim.model.jnt_type[jnt_id] != 0:
+ dof_idx = [i for i, v in enumerate(self.sim.model.dof_jntid) if v == jnt_id]
+ self.sim.model.dof_damping[dof_idx] = val
+
+ def mod_armature(self, name, val):
+ """
+ Modifies the @name's joint armature within the simulation.
+ See http://www.mujoco.org/book/XMLreference.html#joint for more details.
+
+ Args:
+ name (str): Name for this element.
+ val (float): New armature.
+ """
+ # Modify this value (only if it's not a free joint)
+ jnt_id = self.sim.model.joint_name2id(name)
+ if self.sim.model.jnt_type[jnt_id] != 0:
+ dof_idx = [i for i, v in enumerate(self.sim.model.dof_jntid) if v == jnt_id]
+ self.sim.model.dof_armature[dof_idx] = val
+
+ @property
+ def dynamics_parameters(self):
+ """
+ Returns:
+ set: All dynamics parameters that can be randomized using this modder.
+ """
+ return {
+ # Opt parameters
+ "density",
+ "viscosity",
+ # Body parameters
+ "position",
+ "quaternion",
+ "inertia",
+ "mass",
+ # Geom parameters
+ "friction",
+ "solref",
+ "solimp",
+ # Joint parameters
+ "stiffness",
+ "frictionloss",
+ "damping",
+ "armature",
+ }
+
+ @property
+ def opt(self):
+ """
+ Returns:
+ PyMjOption: MjModel sim options
+ """
+ return self.sim.model.opt
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/numba.py b/phantom/submodules/phantom-robosuite/robosuite/utils/numba.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd5d8549758e24993abbd72bbf970d90c0e0091b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/numba.py
@@ -0,0 +1,12 @@
+"""
+Numba utils.
+"""
+import numba
+
+import robosuite.macros as macros
+
+
+def jit_decorator(func):
+ if macros.ENABLE_NUMBA:
+ return numba.jit(nopython=True, cache=macros.CACHE_NUMBA)(func)
+ return func
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/observables.py b/phantom/submodules/phantom-robosuite/robosuite/utils/observables.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c23e8189bcdc1d84a47efe14299a661e11de468
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/observables.py
@@ -0,0 +1,403 @@
+import numpy as np
+
+
+def sensor(modality):
+ """
+ Decorator that should be added to any sensors that will be an observable.
+
+ Decorated functions should have signature:
+
+ any = func(obs_cache)
+
+ Where @obs_cache is a dictionary mapping observable keys to pre-computed values, and @any is either a scalar
+ or array. This function should also handle the case if obs_cache is either None or an empty dict.
+
+ An example use case is shown below:
+
+ >>> @sensor(modality="proprio")
+ >>> def joint_pos(obs_cache):
+ # Always handle case if obs_cache is empty
+ if not obs_cache:
+ return np.zeros(7)
+ # Otherwise, run necessary calculations and return output
+ ...
+ out = ...
+ return out
+
+ Args:
+ modality (str): Modality for this sensor
+
+ Returns:
+ function: decorator function
+ """
+ # Define standard decorator (with no args)
+ def decorator(func):
+ # Add modality attribute
+ func.__modality__ = modality
+ # Return function
+ return func
+
+ return decorator
+
+
+def create_deterministic_corrupter(corruption, low=-np.inf, high=np.inf):
+ """
+ Creates a deterministic corrupter that applies the same corrupted value to all sensor values
+
+ Args:
+ corruption (float): Corruption to apply
+ low (float): Minimum value for output for clipping
+ high (float): Maximum value for output for clipping
+
+ Returns:
+ function: corrupter
+ """
+
+ def corrupter(inp):
+ inp = np.array(inp)
+ return np.clip(inp + corruption, low, high)
+
+ return corrupter
+
+
+def create_uniform_noise_corrupter(min_noise, max_noise, low=-np.inf, high=np.inf):
+ """
+ Creates a corrupter that applies uniform noise to a given input within range @low to @high
+
+ Args:
+ min_noise (float): Minimum noise to apply
+ max_noise (float): Maximum noise to apply
+ low (float): Minimum value for output for clipping
+ high (float): Maxmimum value for output for clipping
+
+ Returns:
+ function: corrupter
+ """
+
+ def corrupter(inp):
+ inp = np.array(inp)
+ noise = (max_noise - min_noise) * np.random.random_sample(inp.shape) + min_noise
+ return np.clip(inp + noise, low, high)
+
+ return corrupter
+
+
+def create_gaussian_noise_corrupter(mean, std, low=-np.inf, high=np.inf):
+ """
+ Creates a corrupter that applies gaussian noise to a given input with mean @mean and std dev @std
+
+ Args:
+ mean (float): Mean of the noise to apply
+ std (float): Standard deviation of the noise to apply
+ low (float): Minimum value for output for clipping
+ high (float): Maxmimum value for output for clipping
+
+ Returns:
+ function: corrupter
+ """
+
+ def corrupter(inp):
+ inp = np.array(inp)
+ noise = mean + std * np.random.randn(*inp.shape)
+ return np.clip(inp + noise, low, high)
+
+ return corrupter
+
+
+def create_deterministic_delayer(delay):
+ """
+ Create a deterministic delayer that always returns the same delay value
+
+ Args:
+ delay (float): Delay value to return
+
+ Returns:
+ function: delayer
+ """
+ assert delay >= 0, "Inputted delay must be non-negative!"
+ return lambda: delay
+
+
+def create_uniform_sampled_delayer(min_delay, max_delay):
+ """
+ Creates uniformly sampled delayer, with minimum delay @low and maximum delay @high, both inclusive
+
+ Args:
+ min_delay (float): Minimum possible delay
+ max_delay (float): Maxmimum possible delay
+
+ Returns:
+ function: delayer
+ """
+ assert min(min_delay, max_delay) >= 0, "Inputted delay must be non-negative!"
+ return lambda: min_delay + (max_delay - min_delay) * np.random.random()
+
+
+def create_gaussian_sampled_delayer(mean, std):
+ """
+ Creates a gaussian sampled delayer, with average delay @mean which varies by standard deviation @std
+
+ Args:
+ mean (float): Average delay
+ std (float): Standard deviation of the delay variation
+
+ Returns:
+ function: delayer
+ """
+ assert mean >= 0, "Inputted mean delay must be non-negative!"
+ return lambda: max(0.0, int(np.round(mean + std * np.random.randn())))
+
+
+# Common defaults to use
+NO_CORRUPTION = lambda inp: inp
+NO_FILTER = lambda inp: inp
+NO_DELAY = lambda: 0.0
+
+
+class Observable:
+ """
+ Base class for all observables -- defines interface for interacting with sensors
+
+ Args:
+ name (str): Name for this observable
+ sensor (function with `sensor` decorator): Method to grab raw sensor data for this observable. Should take in a
+ single dict argument (observation cache if a pre-computed value is required) and return the raw sensor data
+ for the current timestep. Must handle case if inputted argument is empty ({}), and should have `sensor`
+ decorator when defined
+ corrupter (None or function): Method to corrupt the raw sensor data for this observable. Should take in
+ the output of @sensor and return the same type (corrupted data). If None, results in default no corruption
+ filter (None or function): Method to filter the outputted reading for this observable. Should take in the output
+ of @corrupter and return the same type (filtered data). If None, results in default no filter. Note that
+ this function can also double as an observer, where sampled data is recorded by this function.
+ delayer (None or function): Method to delay the raw sensor data when polling this observable. Should take in
+ no arguments and return a float, for the number of seconds to delay the measurement by. If None, results in
+ default no delayer
+ sampling_rate (float): Sampling rate for this observable (Hz)
+ enabled (bool): Whether this sensor is enabled or not. If enabled, this observable's values
+ are continually computed / updated every time update() is called.
+ active (bool): Whether this sensor is active or not. If active, this observable's current
+ observed value is returned from self.obs, otherwise self.obs returns None.
+ """
+
+ def __init__(
+ self,
+ name,
+ sensor,
+ corrupter=None,
+ filter=None,
+ delayer=None,
+ sampling_rate=20,
+ enabled=True,
+ active=True,
+ ):
+ # Set all internal variables and methods
+ self.name = name
+ self._sensor = sensor
+ self._corrupter = corrupter if corrupter is not None else NO_CORRUPTION
+ self._filter = filter if filter is not None else NO_FILTER
+ self._delayer = delayer if delayer is not None else NO_DELAY
+ self._sampling_timestep = 1.0 / sampling_rate
+ self._enabled = enabled
+ self._active = active
+ self._is_number = False # filled in during sensor check call
+ self._data_shape = (1,) # filled in during sensor check call
+
+ # Make sure sensor is working
+ self._check_sensor_validity()
+
+ # These values will be modified during update() call
+ self._time_since_last_sample = 0.0 # seconds
+ self._current_delay = self._delayer() # seconds
+ self._current_observed_value = 0 if self._is_number else np.zeros(self._data_shape)
+ self._sampled = False
+
+ def update(self, timestep, obs_cache, force=False):
+ """
+ Updates internal values for this observable, if enabled.
+
+ Args:
+ timestep (float): Amount of simulation time (in sec) that has passed since last call.
+ obs_cache (dict): Observation cache mapping observable names to pre-computed values to pass to sensor. This
+ will be updated in-place during this call.
+ force (bool): If True, will force the observable to update its internal value to the newest value.
+ """
+ if self._enabled:
+ # Increment internal time counter
+ self._time_since_last_sample += timestep
+
+ # If the delayed sampling time has been passed and we haven't sampled yet for this sampling period,
+ # we should grab a new measurement
+ if (
+ not self._sampled and self._sampling_timestep - self._current_delay >= self._time_since_last_sample
+ ) or force:
+ # Get newest raw value, corrupt it, filter it, and set it as our current observed value
+ obs = np.array(self._filter(self._corrupter(self._sensor(obs_cache))))
+ self._current_observed_value = obs[0] if len(obs.shape) == 1 and obs.shape[0] == 1 else obs
+ # Update cache entry as well
+ obs_cache[self.name] = np.array(self._current_observed_value)
+ # Toggle sampled and re-sample next time delay
+ self._sampled = True
+ self._current_delay = self._delayer()
+
+ # If our total time since last sample has surpassed our sampling timestep,
+ # then we reset our timer and sampled flag
+ if self._time_since_last_sample >= self._sampling_timestep:
+ if not self._sampled:
+ # If we still haven't sampled yet, sample immediately and warn user that sampling rate is too low
+ print(
+ f"Warning: sampling rate for observable {self.name} is either too low or delay is too high. "
+ f"Please adjust one (or both)"
+ )
+ # Get newest raw value, corrupt it, filter it, and set it as our current observed value
+ obs = np.array(self._filter(self._corrupter(self._sensor(obs_cache))))
+ self._current_observed_value = obs[0] if len(obs.shape) == 1 and obs.shape[0] == 1 else obs
+ # Update cache entry as well
+ obs_cache[self.name] = np.array(self._current_observed_value)
+ # Re-sample next time delay
+ self._current_delay = self._delayer()
+ self._time_since_last_sample %= self._sampling_timestep
+ self._sampled = False
+
+ def reset(self):
+ """
+ Resets this observable's internal values (but does not reset its sensor, corrupter, delayer, or filter)
+ """
+ self._time_since_last_sample = 0.0
+ self._current_delay = self._delayer()
+ self._current_observed_value = 0 if self._is_number else np.zeros(self._data_shape)
+
+ def is_enabled(self):
+ """
+ Determines whether observable is enabled or not. This observable is considered enabled if its values
+ are being continually computed / updated during each update() call.
+
+ Returns:
+ bool: True if this observable is enabled
+ """
+ return self._enabled
+
+ def is_active(self):
+ """
+ Determines whether observable is active or not. This observable is considered active if its current observation
+ value is being returned in self.obs.
+
+ Returns:
+ bool: True if this observable is active
+ """
+ return self._active
+
+ def set_enabled(self, enabled):
+ """
+ Sets whether this observable is enabled or not. If enabled, this observable's values
+ are continually computed / updated every time update() is called.
+
+ Args:
+ enabled (bool): True if this observable should be enabled
+ """
+ self._enabled = enabled
+ # Reset values
+ self.reset()
+
+ def set_active(self, active):
+ """
+ Sets whether this observable is active or not. If active, this observable's current
+ observed value is returned from self.obs, otherwise self.obs returns None.
+
+ Args:
+ active (bool): True if this observable should be active
+ """
+ self._active = active
+
+ def set_sensor(self, sensor):
+ """
+ Sets the sensor for this observable.
+
+ Args:
+ sensor (function with sensor decorator): Method to grab raw sensor data for this observable. Should take in
+ a single dict argument (observation cache if a pre-computed value is required) and return the raw
+ sensor data for the current timestep. Must handle case if inputted argument is empty ({}), and should
+ have `sensor` decorator when defined
+ """
+ self._sensor = sensor
+ self._check_sensor_validity()
+
+ def set_corrupter(self, corrupter):
+ """
+ Sets the corrupter for this observable.
+
+ Args:
+ corrupter (None or function): Method to corrupt the raw sensor data for this observable. Should take in
+ the output of self.sensor and return the same type (corrupted data).
+ If None, results in default no corruption
+ """
+ self._corrupter = corrupter if corrupter is not None else NO_CORRUPTION
+
+ def set_filter(self, filter):
+ """
+ Sets the filter for this observable. Note that this function can also double as an observer, where sampled
+ data is recorded by this function.
+
+ Args:
+ filter (None or function): Method to filter the outputted reading for this observable. Should take in
+ the output of @corrupter and return the same type (filtered data).
+ If None, results in default no filter
+ """
+ self._filter = filter if filter is not None else NO_FILTER
+
+ def set_delayer(self, delayer):
+ """
+ Sets the delayer for this observable.
+
+ Args:
+ delayer (None or function): Method to delay the raw sensor data when polling this observable. Should take
+ in no arguments and return a float, for the number of seconds to delay the measurement by.
+ If None, results in default no filter
+ """
+ self._delayer = delayer if delayer is not None else NO_DELAY
+
+ def set_sampling_rate(self, rate):
+ """
+ Sets the sampling rate for this observable.
+
+ Args:
+ rate (int): New sampling rate for this observable (Hz)
+ """
+ self._sampling_timestep = 1.0 / rate
+
+ def _check_sensor_validity(self):
+ """
+ Internal function that checks the validity of this observable's sensor. It does the following:
+
+ - Asserts that the inputted sensor has its __modality__ attribute defined from the sensor decorator
+ - Asserts that the inputted sensor can handle the empty dict {} arg case
+ - Updates the corresponding name, and data-types for this sensor
+ """
+ try:
+ _ = self.modality
+ self._data_shape = np.array(self._sensor({})).shape
+ self._is_number = len(self._data_shape) == 1 and self._data_shape[0] == 1
+ except Exception as e:
+ from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER
+
+ ROBOSUITE_DEFAULT_LOGGER.error(e)
+ raise ValueError("Current sensor for observable {} is invalid.".format(self.name))
+
+ @property
+ def obs(self):
+ """
+ Current observation from this observable
+
+ Returns:
+ None or float or np.array: If active, current observed value from this observable. Otherwise, None
+ """
+ return self._current_observed_value if self._active else None
+
+ @property
+ def modality(self):
+ """
+ Modality of this sensor
+
+ Returns:
+ str: Modality name for this observable
+ """
+ return self._sensor.__modality__
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/opencv_renderer.py b/phantom/submodules/phantom-robosuite/robosuite/utils/opencv_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6d6386b75870248519a12a18e62c233e256c246
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/opencv_renderer.py
@@ -0,0 +1,50 @@
+"""
+opencv renderer class.
+"""
+import cv2
+import numpy as np
+
+
+class OpenCVRenderer:
+ def __init__(self, sim):
+ # TODO: update this appropriately - need to get screen dimensions
+ self.width = 1280
+ self.height = 800
+
+ self.sim = sim
+ self.camera_name = self.sim.model.camera_id2name(0)
+
+ self.keypress_callback = None
+
+ def set_camera(self, camera_id):
+ """
+ Set the camera view to the specified camera ID.
+ Args:
+ camera_id (int): id of the camera to set the current viewer to
+ """
+ self.camera_name = self.sim.model.camera_id2name(camera_id)
+
+ def render(self):
+ # get frame with offscreen renderer (assumes that the renderer already exists)
+ im = self.sim.render(camera_name=self.camera_name, height=self.height, width=self.width)[..., ::-1]
+
+ # write frame to window
+ im = np.flip(im, axis=0)
+ cv2.imshow("offscreen render", im)
+ key = cv2.waitKey(1)
+ if self.keypress_callback:
+ self.keypress_callback(key)
+
+ def add_keypress_callback(self, keypress_callback):
+ self.keypress_callback = keypress_callback
+
+ def close(self):
+ """
+ Any cleanup to close renderer.
+ """
+
+ # NOTE: assume that @sim will get cleaned up outside the renderer - just delete the reference
+ self.sim = None
+
+ # close window
+ cv2.destroyAllWindows()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/placement_samplers.py b/phantom/submodules/phantom-robosuite/robosuite/utils/placement_samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bd8bb99712d8dac51ca5acf90261ca21cba60c4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/placement_samplers.py
@@ -0,0 +1,441 @@
+import collections
+from copy import copy
+
+import numpy as np
+
+from robosuite.models.objects import MujocoObject
+from robosuite.utils import RandomizationError
+from robosuite.utils.transform_utils import quat_multiply
+
+
+class ObjectPositionSampler:
+ """
+ Base class of object placement sampler.
+
+ Args:
+ name (str): Name of this sampler.
+
+ mujoco_objects (None or MujocoObject or list of MujocoObject): single model or list of MJCF object models
+
+ ensure_object_boundary_in_range (bool): If True, will ensure that the object is enclosed within a given boundary
+ (should be implemented by subclass)
+
+ ensure_valid_placement (bool): If True, will check for correct (valid) object placements
+
+ reference_pos (3-array): global (x,y,z) position relative to which sampling will occur
+
+ z_offset (float): Add a small z-offset to placements. This is useful for fixed objects
+ that do not move (i.e. no free joint) to place them above the table.
+ """
+
+ def __init__(
+ self,
+ name,
+ mujoco_objects=None,
+ ensure_object_boundary_in_range=True,
+ ensure_valid_placement=True,
+ reference_pos=(0, 0, 0),
+ z_offset=0.0,
+ ):
+ # Setup attributes
+ self.name = name
+ if mujoco_objects is None:
+ self.mujoco_objects = []
+ else:
+ # Shallow copy the list so we don't modify the inputted list but still keep the object references
+ self.mujoco_objects = [mujoco_objects] if isinstance(mujoco_objects, MujocoObject) else copy(mujoco_objects)
+ self.ensure_object_boundary_in_range = ensure_object_boundary_in_range
+ self.ensure_valid_placement = ensure_valid_placement
+ self.reference_pos = reference_pos
+ self.z_offset = z_offset
+
+ def add_objects(self, mujoco_objects):
+ """
+ Add additional objects to this sampler. Checks to make sure there's no identical objects already stored.
+
+ Args:
+ mujoco_objects (MujocoObject or list of MujocoObject): single model or list of MJCF object models
+ """
+ mujoco_objects = [mujoco_objects] if isinstance(mujoco_objects, MujocoObject) else mujoco_objects
+ for obj in mujoco_objects:
+ assert obj not in self.mujoco_objects, "Object '{}' already in sampler!".format(obj.name)
+ self.mujoco_objects.append(obj)
+
+ def reset(self):
+ """
+ Resets this sampler. Removes all mujoco objects from this sampler.
+ """
+ self.mujoco_objects = []
+
+ def sample(self, fixtures=None, reference=None, on_top=True):
+ """
+ Uniformly sample on a surface (not necessarily table surface).
+
+ Args:
+ fixtures (dict): dictionary of current object placements in the scene as well as any other relevant
+ obstacles that should not be in contact with newly sampled objects. Used to make sure newly
+ generated placements are valid. Should be object names mapped to (pos, quat, MujocoObject)
+
+ reference (str or 3-tuple or None): if provided, sample relative placement. Can either be a string, which
+ corresponds to an existing object found in @fixtures, or a direct (x,y,z) value. If None, will sample
+ relative to this sampler's `'reference_pos'` value.
+
+ on_top (bool): if True, sample placement on top of the reference object.
+
+ Return:
+ dict: dictionary of all object placements, mapping object_names to (pos, quat, obj), including the
+ placements specified in @fixtures. Note quat is in (w,x,y,z) form
+ """
+ raise NotImplementedError
+
+
+class UniformRandomSampler(ObjectPositionSampler):
+ """
+ Places all objects within the table uniformly random.
+
+ Args:
+ name (str): Name of this sampler.
+
+ mujoco_objects (None or MujocoObject or list of MujocoObject): single model or list of MJCF object models
+
+ x_range (2-array of float): Specify the (min, max) relative x_range used to uniformly place objects
+
+ y_range (2-array of float): Specify the (min, max) relative y_range used to uniformly place objects
+
+ rotation (None or float or Iterable):
+ :`None`: Add uniform random random rotation
+ :`Iterable (a,b)`: Uniformly randomize rotation angle between a and b (in radians)
+ :`value`: Add fixed angle rotation
+
+ rotation_axis (str): Can be 'x', 'y', or 'z'. Axis about which to apply the requested rotation
+
+ ensure_object_boundary_in_range (bool):
+ :`True`: The center of object is at position:
+ [uniform(min x_range + radius, max x_range - radius)], [uniform(min x_range + radius, max x_range - radius)]
+ :`False`:
+ [uniform(min x_range, max x_range)], [uniform(min x_range, max x_range)]
+
+ ensure_valid_placement (bool): If True, will check for correct (valid) object placements
+
+ reference_pos (3-array): global (x,y,z) position relative to which sampling will occur
+
+ z_offset (float): Add a small z-offset to placements. This is useful for fixed objects
+ that do not move (i.e. no free joint) to place them above the table.
+ """
+
+ def __init__(
+ self,
+ name,
+ mujoco_objects=None,
+ x_range=(0, 0),
+ y_range=(0, 0),
+ rotation=None,
+ rotation_axis="z",
+ ensure_object_boundary_in_range=True,
+ ensure_valid_placement=True,
+ reference_pos=(0, 0, 0),
+ z_offset=0.0,
+ ):
+ self.x_range = x_range
+ self.y_range = y_range
+ self.rotation = rotation
+ self.rotation_axis = rotation_axis
+
+ super().__init__(
+ name=name,
+ mujoco_objects=mujoco_objects,
+ ensure_object_boundary_in_range=ensure_object_boundary_in_range,
+ ensure_valid_placement=ensure_valid_placement,
+ reference_pos=reference_pos,
+ z_offset=z_offset,
+ )
+
+ def _sample_x(self, object_horizontal_radius):
+ """
+ Samples the x location for a given object
+
+ Args:
+ object_horizontal_radius (float): Radius of the object currently being sampled for
+
+ Returns:
+ float: sampled x position
+ """
+ minimum, maximum = self.x_range
+ if self.ensure_object_boundary_in_range:
+ minimum += object_horizontal_radius
+ maximum -= object_horizontal_radius
+ return np.random.uniform(high=maximum, low=minimum)
+
+ def _sample_y(self, object_horizontal_radius):
+ """
+ Samples the y location for a given object
+
+ Args:
+ object_horizontal_radius (float): Radius of the object currently being sampled for
+
+ Returns:
+ float: sampled y position
+ """
+ minimum, maximum = self.y_range
+ if self.ensure_object_boundary_in_range:
+ minimum += object_horizontal_radius
+ maximum -= object_horizontal_radius
+ return np.random.uniform(high=maximum, low=minimum)
+
+ def _sample_quat(self):
+ """
+ Samples the orientation for a given object
+
+ Returns:
+ np.array: sampled object quaternion in (w,x,y,z) form
+
+ Raises:
+ ValueError: [Invalid rotation axis]
+ """
+ if self.rotation is None:
+ rot_angle = np.random.uniform(high=2 * np.pi, low=0)
+ elif isinstance(self.rotation, collections.abc.Iterable):
+ rot_angle = np.random.uniform(high=max(self.rotation), low=min(self.rotation))
+ else:
+ rot_angle = self.rotation
+
+ # Return angle based on axis requested
+ if self.rotation_axis == "x":
+ return np.array([np.cos(rot_angle / 2), np.sin(rot_angle / 2), 0, 0])
+ elif self.rotation_axis == "y":
+ return np.array([np.cos(rot_angle / 2), 0, np.sin(rot_angle / 2), 0])
+ elif self.rotation_axis == "z":
+ return np.array([np.cos(rot_angle / 2), 0, 0, np.sin(rot_angle / 2)])
+ else:
+ # Invalid axis specified, raise error
+ raise ValueError(
+ "Invalid rotation axis specified. Must be 'x', 'y', or 'z'. Got: {}".format(self.rotation_axis)
+ )
+
+ def sample(self, fixtures=None, reference=None, on_top=True):
+ """
+ Uniformly sample relative to this sampler's reference_pos or @reference (if specified).
+
+ Args:
+ fixtures (dict): dictionary of current object placements in the scene as well as any other relevant
+ obstacles that should not be in contact with newly sampled objects. Used to make sure newly
+ generated placements are valid. Should be object names mapped to (pos, quat, MujocoObject)
+
+ reference (str or 3-tuple or None): if provided, sample relative placement. Can either be a string, which
+ corresponds to an existing object found in @fixtures, or a direct (x,y,z) value. If None, will sample
+ relative to this sampler's `'reference_pos'` value.
+
+ on_top (bool): if True, sample placement on top of the reference object. This corresponds to a sampled
+ z-offset of the current sampled object's bottom_offset + the reference object's top_offset
+ (if specified)
+
+ Return:
+ dict: dictionary of all object placements, mapping object_names to (pos, quat, obj), including the
+ placements specified in @fixtures. Note quat is in (w,x,y,z) form
+
+ Raises:
+ RandomizationError: [Cannot place all objects]
+ AssertionError: [Reference object name does not exist, invalid inputs]
+ """
+ # Standardize inputs
+ placed_objects = {} if fixtures is None else copy(fixtures)
+ if reference is None:
+ base_offset = self.reference_pos
+ elif type(reference) is str:
+ assert (
+ reference in placed_objects
+ ), "Invalid reference received. Current options are: {}, requested: {}".format(
+ placed_objects.keys(), reference
+ )
+ ref_pos, _, ref_obj = placed_objects[reference]
+ base_offset = np.array(ref_pos)
+ if on_top:
+ base_offset += np.array((0, 0, ref_obj.top_offset[-1]))
+ else:
+ base_offset = np.array(reference)
+ assert (
+ base_offset.shape[0] == 3
+ ), "Invalid reference received. Should be (x,y,z) 3-tuple, but got: {}".format(base_offset)
+
+ # Sample pos and quat for all objects assigned to this sampler
+ for obj in self.mujoco_objects:
+ # First make sure the currently sampled object hasn't already been sampled
+ assert obj.name not in placed_objects, "Object '{}' has already been sampled!".format(obj.name)
+
+ horizontal_radius = obj.horizontal_radius
+ bottom_offset = obj.bottom_offset
+ success = False
+ for i in range(5000): # 5000 retries
+ object_x = self._sample_x(horizontal_radius) + base_offset[0]
+ object_y = self._sample_y(horizontal_radius) + base_offset[1]
+ object_z = self.z_offset + base_offset[2]
+ if on_top:
+ object_z -= bottom_offset[-1]
+
+ # objects cannot overlap
+ location_valid = True
+ if self.ensure_valid_placement:
+ for (x, y, z), _, other_obj in placed_objects.values():
+ if (
+ np.linalg.norm((object_x - x, object_y - y))
+ <= other_obj.horizontal_radius + horizontal_radius
+ ) and (object_z - z <= other_obj.top_offset[-1] - bottom_offset[-1]):
+ location_valid = False
+ break
+
+ if location_valid:
+ # random rotation
+ quat = self._sample_quat()
+
+ # multiply this quat by the object's initial rotation if it has the attribute specified
+ if hasattr(obj, "init_quat"):
+ quat = quat_multiply(quat, obj.init_quat)
+
+ # location is valid, put the object down
+ pos = (object_x, object_y, object_z)
+ placed_objects[obj.name] = (pos, quat, obj)
+ success = True
+ break
+
+ if not success:
+ raise RandomizationError("Cannot place all objects ):")
+
+ return placed_objects
+
+
+class SequentialCompositeSampler(ObjectPositionSampler):
+ """
+ Samples position for each object sequentially. Allows chaining
+ multiple placement initializers together - so that object locations can
+ be sampled on top of other objects or relative to other object placements.
+
+ Args:
+ name (str): Name of this sampler.
+ """
+
+ def __init__(self, name):
+ # Samplers / args will be filled in later
+ self.samplers = collections.OrderedDict()
+ self.sample_args = collections.OrderedDict()
+
+ super().__init__(name=name)
+
+ def append_sampler(self, sampler, sample_args=None):
+ """
+ Adds a new placement initializer with corresponding @sampler and arguments
+
+ Args:
+ sampler (ObjectPositionSampler): sampler to add
+ sample_args (None or dict): If specified, should be additional arguments to pass to @sampler's sample()
+ call. Should map corresponding sampler's arguments to values (excluding @fixtures argument)
+
+ Raises:
+ AssertionError: [Object name in samplers]
+ """
+ # Verify that all added mujoco objects haven't already been added, and add to this sampler's objects dict
+ for obj in sampler.mujoco_objects:
+ assert obj not in self.mujoco_objects, f"Object '{obj.name}' already has sampler associated with it!"
+ self.mujoco_objects.append(obj)
+ self.samplers[sampler.name] = sampler
+ self.sample_args[sampler.name] = sample_args
+
+ def hide(self, mujoco_objects):
+ """
+ Helper method to remove an object from the workspace.
+
+ Args:
+ mujoco_objects (MujocoObject or list of MujocoObject): Object(s) to hide
+ """
+ sampler = UniformRandomSampler(
+ name="HideSampler",
+ mujoco_objects=mujoco_objects,
+ x_range=[-10, -20],
+ y_range=[-10, -20],
+ rotation=[0, 0],
+ rotation_axis="z",
+ z_offset=10,
+ ensure_object_boundary_in_range=False,
+ ensure_valid_placement=False,
+ )
+ self.append_sampler(sampler=sampler)
+
+ def add_objects(self, mujoco_objects):
+ """
+ Override super method to make sure user doesn't call this (all objects should implicitly belong to sub-samplers)
+ """
+ raise AttributeError("add_objects() should not be called for SequentialCompsiteSamplers!")
+
+ def add_objects_to_sampler(self, sampler_name, mujoco_objects):
+ """
+ Adds specified @mujoco_objects to sub-sampler with specified @sampler_name.
+
+ Args:
+ sampler_name (str): Existing sub-sampler name
+ mujoco_objects (MujocoObject or list of MujocoObject): Object(s) to add
+ """
+ # First verify that all mujoco objects haven't already been added, and add to this sampler's objects dict
+ mujoco_objects = [mujoco_objects] if isinstance(mujoco_objects, MujocoObject) else mujoco_objects
+ for obj in mujoco_objects:
+ assert obj not in self.mujoco_objects, f"Object '{obj.name}' already has sampler associated with it!"
+ self.mujoco_objects.append(obj)
+ # Make sure sampler_name exists
+ assert (
+ sampler_name in self.samplers.keys()
+ ), "Invalid sub-sampler specified, valid options are: {}, " "requested: {}".format(
+ self.samplers.keys(), sampler_name
+ )
+ # Add the mujoco objects to the requested sub-sampler
+ self.samplers[sampler_name].add_objects(mujoco_objects)
+
+ def reset(self):
+ """
+ Resets this sampler. In addition to base method, iterates over all sub-samplers and resets them
+ """
+ super().reset()
+ for sampler in self.samplers.values():
+ sampler.reset()
+
+ def sample(self, fixtures=None, reference=None, on_top=True):
+ """
+ Sample from each placement initializer sequentially, in the order
+ that they were appended.
+
+ Args:
+ fixtures (dict): dictionary of current object placements in the scene as well as any other relevant
+ obstacles that should not be in contact with newly sampled objects. Used to make sure newly
+ generated placements are valid. Should be object names mapped to (pos, quat, MujocoObject)
+
+ reference (str or 3-tuple or None): if provided, sample relative placement. This will override each
+ sampler's @reference argument if not already specified. Can either be a string, which
+ corresponds to an existing object found in @fixtures, or a direct (x,y,z) value. If None, will sample
+ relative to this sampler's `'reference_pos'` value.
+
+ on_top (bool): if True, sample placement on top of the reference object. This will override each
+ sampler's @on_top argument if not already specified. This corresponds to a sampled
+ z-offset of the current sampled object's bottom_offset + the reference object's top_offset
+ (if specified)
+
+ Return:
+ dict: dictionary of all object placements, mapping object_names to (pos, quat, obj), including the
+ placements specified in @fixtures. Note quat is in (w,x,y,z) form
+
+ Raises:
+ RandomizationError: [Cannot place all objects]
+ """
+ # Standardize inputs
+ placed_objects = {} if fixtures is None else copy(fixtures)
+
+ # Iterate through all samplers to sample
+ for sampler, s_args in zip(self.samplers.values(), self.sample_args.values()):
+ # Pre-process sampler args
+ if s_args is None:
+ s_args = {}
+ for arg_name, arg in zip(("reference", "on_top"), (reference, on_top)):
+ if arg_name not in s_args:
+ s_args[arg_name] = arg
+ # Run sampler
+ new_placements = sampler.sample(fixtures=placed_objects, **s_args)
+ # Update placements
+ placed_objects.update(new_placements)
+
+ return placed_objects
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/robot_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/robot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..53200dcd6d64c56c6a622b684ba10c62498e6a40
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/robot_utils.py
@@ -0,0 +1,16 @@
+# Utilities functions for working with robots
+
+from robosuite.robots import BIMANUAL_ROBOTS
+
+
+def check_bimanual(robot_name):
+ """
+ Utility function that returns whether the inputted robot_name is a bimanual robot or not
+
+ Args:
+ robot_name (str): Name of the robot to check
+
+ Returns:
+ bool: True if the inputted robot is a bimanual robot
+ """
+ return robot_name.lower() in BIMANUAL_ROBOTS
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/sim_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/sim_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5452add1a7ee82b947e81c76bd43698ff128047d
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/sim_utils.py
@@ -0,0 +1,67 @@
+"""
+Collection of useful simulation utilities
+"""
+
+from robosuite.models.base import MujocoModel
+
+
+def check_contact(sim, geoms_1, geoms_2=None):
+ """
+ Finds contact between two geom groups.
+ Args:
+ sim (MjSim): Current simulation object
+ geoms_1 (str or list of str or MujocoModel): an individual geom name or list of geom names or a model. If
+ a MujocoModel is specified, the geoms checked will be its contact_geoms
+ geoms_2 (str or list of str or MujocoModel or None): another individual geom name or list of geom names.
+ If a MujocoModel is specified, the geoms checked will be its contact_geoms. If None, will check
+ any collision with @geoms_1 to any other geom in the environment
+ Returns:
+ bool: True if any geom in @geoms_1 is in contact with any geom in @geoms_2.
+ """
+ # Check if either geoms_1 or geoms_2 is a string, convert to list if so
+ if type(geoms_1) is str:
+ geoms_1 = [geoms_1]
+ elif isinstance(geoms_1, MujocoModel):
+ geoms_1 = geoms_1.contact_geoms
+ if type(geoms_2) is str:
+ geoms_2 = [geoms_2]
+ elif isinstance(geoms_2, MujocoModel):
+ geoms_2 = geoms_2.contact_geoms
+ for i in range(sim.data.ncon):
+ contact = sim.data.contact[i]
+ # check contact geom in geoms
+ c1_in_g1 = sim.model.geom_id2name(contact.geom1) in geoms_1
+ c2_in_g2 = sim.model.geom_id2name(contact.geom2) in geoms_2 if geoms_2 is not None else True
+ # check contact geom in geoms (flipped)
+ c2_in_g1 = sim.model.geom_id2name(contact.geom2) in geoms_1
+ c1_in_g2 = sim.model.geom_id2name(contact.geom1) in geoms_2 if geoms_2 is not None else True
+ if (c1_in_g1 and c2_in_g2) or (c1_in_g2 and c2_in_g1):
+ return True
+ return False
+
+
+def get_contacts(sim, model):
+ """
+ Checks for any contacts with @model (as defined by @model's contact_geoms) and returns the set of
+ geom names currently in contact with that model (excluding the geoms that are part of the model itself).
+ Args:
+ sim (MjSim): Current simulation model
+ model (MujocoModel): Model to check contacts for.
+ Returns:
+ set: Unique geoms that are actively in contact with this model.
+ Raises:
+ AssertionError: [Invalid input type]
+ """
+ # Make sure model is MujocoModel type
+ assert isinstance(model, MujocoModel), "Inputted model must be of type MujocoModel; got type {} instead!".format(
+ type(model)
+ )
+ contact_set = set()
+ for contact in sim.data.contact[: sim.data.ncon]:
+ # check contact geom in geoms; add to contact set if match is found
+ g1, g2 = sim.model.geom_id2name(contact.geom1), sim.model.geom_id2name(contact.geom2)
+ if g1 in model.contact_geoms and g2 not in model.contact_geoms:
+ contact_set.add(g2)
+ elif g2 in model.contact_geoms and g1 not in model.contact_geoms:
+ contact_set.add(g1)
+ return contact_set
diff --git a/phantom/submodules/phantom-robosuite/robosuite/utils/transform_utils.py b/phantom/submodules/phantom-robosuite/robosuite/utils/transform_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac371c09da962dcceea8a376876fc1d46c0e52da
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/utils/transform_utils.py
@@ -0,0 +1,929 @@
+"""
+Utility functions of matrix and vector transformations.
+
+NOTE: convention for quaternions is (x, y, z, w)
+"""
+
+import math
+
+import numpy as np
+
+from robosuite.utils.numba import jit_decorator
+
+PI = np.pi
+EPS = np.finfo(float).eps * 4.0
+
+# axis sequences for Euler angles
+_NEXT_AXIS = [1, 2, 0, 1]
+
+# map axes strings to/from tuples of inner axis, parity, repetition, frame
+_AXES2TUPLE = {
+ "sxyz": (0, 0, 0, 0),
+ "sxyx": (0, 0, 1, 0),
+ "sxzy": (0, 1, 0, 0),
+ "sxzx": (0, 1, 1, 0),
+ "syzx": (1, 0, 0, 0),
+ "syzy": (1, 0, 1, 0),
+ "syxz": (1, 1, 0, 0),
+ "syxy": (1, 1, 1, 0),
+ "szxy": (2, 0, 0, 0),
+ "szxz": (2, 0, 1, 0),
+ "szyx": (2, 1, 0, 0),
+ "szyz": (2, 1, 1, 0),
+ "rzyx": (0, 0, 0, 1),
+ "rxyx": (0, 0, 1, 1),
+ "ryzx": (0, 1, 0, 1),
+ "rxzx": (0, 1, 1, 1),
+ "rxzy": (1, 0, 0, 1),
+ "ryzy": (1, 0, 1, 1),
+ "rzxy": (1, 1, 0, 1),
+ "ryxy": (1, 1, 1, 1),
+ "ryxz": (2, 0, 0, 1),
+ "rzxz": (2, 0, 1, 1),
+ "rxyz": (2, 1, 0, 1),
+ "rzyz": (2, 1, 1, 1),
+}
+
+_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
+
+
+def convert_quat(q, to="xyzw"):
+ """
+ Converts quaternion from one convention to another.
+ The convention to convert TO is specified as an optional argument.
+ If to == 'xyzw', then the input is in 'wxyz' format, and vice-versa.
+
+ Args:
+ q (np.array): a 4-dim array corresponding to a quaternion
+ to (str): either 'xyzw' or 'wxyz', determining which convention to convert to.
+ """
+ if to == "xyzw":
+ return q[[1, 2, 3, 0]]
+ if to == "wxyz":
+ return q[[3, 0, 1, 2]]
+ raise Exception("convert_quat: choose a valid `to` argument (xyzw or wxyz)")
+
+
+def quat_multiply(quaternion1, quaternion0):
+ """
+ Return multiplication of two quaternions (q1 * q0).
+
+ E.g.:
+ >>> q = quat_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
+ >>> np.allclose(q, [-44, -14, 48, 28])
+ True
+
+ Args:
+ quaternion1 (np.array): (x,y,z,w) quaternion
+ quaternion0 (np.array): (x,y,z,w) quaternion
+
+ Returns:
+ np.array: (x,y,z,w) multiplied quaternion
+ """
+ x0, y0, z0, w0 = quaternion0
+ x1, y1, z1, w1 = quaternion1
+ return np.array(
+ (
+ x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0,
+ -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0,
+ x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0,
+ -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0,
+ ),
+ dtype=np.float32,
+ )
+
+
+def quat_conjugate(quaternion):
+ """
+ Return conjugate of quaternion.
+
+ E.g.:
+ >>> q0 = random_quaternion()
+ >>> q1 = quat_conjugate(q0)
+ >>> q1[3] == q0[3] and all(q1[:3] == -q0[:3])
+ True
+
+ Args:
+ quaternion (np.array): (x,y,z,w) quaternion
+
+ Returns:
+ np.array: (x,y,z,w) quaternion conjugate
+ """
+ return np.array(
+ (-quaternion[0], -quaternion[1], -quaternion[2], quaternion[3]),
+ dtype=np.float32,
+ )
+
+
+def quat_inverse(quaternion):
+ """
+ Return inverse of quaternion.
+
+ E.g.:
+ >>> q0 = random_quaternion()
+ >>> q1 = quat_inverse(q0)
+ >>> np.allclose(quat_multiply(q0, q1), [0, 0, 0, 1])
+ True
+
+ Args:
+ quaternion (np.array): (x,y,z,w) quaternion
+
+ Returns:
+ np.array: (x,y,z,w) quaternion inverse
+ """
+ return quat_conjugate(quaternion) / np.dot(quaternion, quaternion)
+
+
+def quat_distance(quaternion1, quaternion0):
+ """
+ Returns distance between two quaternions, such that distance * quaternion0 = quaternion1
+
+ Args:
+ quaternion1 (np.array): (x,y,z,w) quaternion
+ quaternion0 (np.array): (x,y,z,w) quaternion
+
+ Returns:
+ np.array: (x,y,z,w) quaternion distance
+ """
+ return quat_multiply(quaternion1, quat_inverse(quaternion0))
+
+
+def quat_slerp(quat0, quat1, fraction, shortestpath=True):
+ """
+ Return spherical linear interpolation between two quaternions.
+
+ E.g.:
+ >>> q0 = random_quat()
+ >>> q1 = random_quat()
+ >>> q = quat_slerp(q0, q1, 0.0)
+ >>> np.allclose(q, q0)
+ True
+
+ >>> q = quat_slerp(q0, q1, 1.0)
+ >>> np.allclose(q, q1)
+ True
+
+ >>> q = quat_slerp(q0, q1, 0.5)
+ >>> angle = math.acos(np.dot(q0, q))
+ >>> np.allclose(2.0, math.acos(np.dot(q0, q1)) / angle) or \
+ np.allclose(2.0, math.acos(-np.dot(q0, q1)) / angle)
+ True
+
+ Args:
+ quat0 (np.array): (x,y,z,w) quaternion startpoint
+ quat1 (np.array): (x,y,z,w) quaternion endpoint
+ fraction (float): fraction of interpolation to calculate
+ shortestpath (bool): If True, will calculate the shortest path
+
+ Returns:
+ np.array: (x,y,z,w) quaternion distance
+ """
+ q0 = unit_vector(quat0[:4])
+ q1 = unit_vector(quat1[:4])
+ if fraction == 0.0:
+ return q0
+ elif fraction == 1.0:
+ return q1
+ d = np.dot(q0, q1)
+ if abs(abs(d) - 1.0) < EPS:
+ return q0
+ if shortestpath and d < 0.0:
+ # invert rotation
+ d = -d
+ q1 *= -1.0
+ angle = math.acos(np.clip(d, -1, 1))
+ if abs(angle) < EPS:
+ return q0
+ isin = 1.0 / math.sin(angle)
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
+ q1 *= math.sin(fraction * angle) * isin
+ q0 += q1
+ return q0
+
+
+def random_quat(rand=None):
+ """
+ Return uniform random unit quaternion.
+
+ E.g.:
+ >>> q = random_quat()
+ >>> np.allclose(1.0, vector_norm(q))
+ True
+ >>> q = random_quat(np.random.random(3))
+ >>> q.shape
+ (4,)
+
+ Args:
+ rand (3-array or None): If specified, must be three independent random variables that are uniformly distributed
+ between 0 and 1.
+
+ Returns:
+ np.array: (x,y,z,w) random quaternion
+ """
+ if rand is None:
+ rand = np.random.rand(3)
+ else:
+ assert len(rand) == 3
+ r1 = np.sqrt(1.0 - rand[0])
+ r2 = np.sqrt(rand[0])
+ pi2 = math.pi * 2.0
+ t1 = pi2 * rand[1]
+ t2 = pi2 * rand[2]
+ return np.array(
+ (np.sin(t1) * r1, np.cos(t1) * r1, np.sin(t2) * r2, np.cos(t2) * r2),
+ dtype=np.float32,
+ )
+
+
+def random_axis_angle(angle_limit=None, random_state=None):
+ """
+ Samples an axis-angle rotation by first sampling a random axis
+ and then sampling an angle. If @angle_limit is provided, the size
+ of the rotation angle is constrained.
+
+ If @random_state is provided (instance of np.random.RandomState), it
+ will be used to generate random numbers.
+
+ Args:
+ angle_limit (None or float): If set, determines magnitude limit of angles to generate
+ random_state (None or RandomState): RNG to use if specified
+
+ Raises:
+ AssertionError: [Invalid RNG]
+ """
+ if angle_limit is None:
+ angle_limit = 2.0 * np.pi
+
+ if random_state is not None:
+ assert isinstance(random_state, np.random.RandomState)
+ npr = random_state
+ else:
+ npr = np.random
+
+ # sample random axis using a normalized sample from spherical Gaussian.
+ # see (http://extremelearning.com.au/how-to-generate-uniformly-random-points-on-n-spheres-and-n-balls/)
+ # for why it works.
+ random_axis = npr.randn(3)
+ random_axis /= np.linalg.norm(random_axis)
+ random_angle = npr.uniform(low=0.0, high=angle_limit)
+ return random_axis, random_angle
+
+
+def vec(values):
+ """
+ Converts value tuple into a numpy vector.
+
+ Args:
+ values (n-array): a tuple of numbers
+
+ Returns:
+ np.array: vector of given values
+ """
+ return np.array(values, dtype=np.float32)
+
+
+def mat4(array):
+ """
+ Converts an array to 4x4 matrix.
+
+ Args:
+ array (n-array): the array in form of vec, list, or tuple
+
+ Returns:
+ np.array: a 4x4 numpy matrix
+ """
+ return np.array(array, dtype=np.float32).reshape((4, 4))
+
+
+def mat2pose(hmat):
+ """
+ Converts a homogeneous 4x4 matrix into pose.
+
+ Args:
+ hmat (np.array): a 4x4 homogeneous matrix
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) (x,y,z) position array in cartesian coordinates
+ - (np.array) (x,y,z,w) orientation array in quaternion form
+ """
+ pos = hmat[:3, 3]
+ orn = mat2quat(hmat[:3, :3])
+ return pos, orn
+
+
+@jit_decorator
+def mat2quat(rmat):
+ """
+ Converts given rotation matrix to quaternion.
+
+ Args:
+ rmat (np.array): 3x3 rotation matrix
+
+ Returns:
+ np.array: (x,y,z,w) float quaternion angles
+ """
+ M = np.asarray(rmat).astype(np.float32)[:3, :3]
+
+ m00 = M[0, 0]
+ m01 = M[0, 1]
+ m02 = M[0, 2]
+ m10 = M[1, 0]
+ m11 = M[1, 1]
+ m12 = M[1, 2]
+ m20 = M[2, 0]
+ m21 = M[2, 1]
+ m22 = M[2, 2]
+ # symmetric matrix K
+ K = np.array(
+ [
+ [m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
+ [m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
+ [m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
+ [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
+ ]
+ )
+ K /= 3.0
+ # quaternion is Eigen vector of K that corresponds to largest eigenvalue
+ w, V = np.linalg.eigh(K)
+ inds = np.array([3, 0, 1, 2])
+ q1 = V[inds, np.argmax(w)]
+ if q1[0] < 0.0:
+ np.negative(q1, q1)
+ inds = np.array([1, 2, 3, 0])
+ return q1[inds]
+
+
+def euler2mat(euler):
+ """
+ Converts euler angles into rotation matrix form
+
+ Args:
+ euler (np.array): (r,p,y) angles
+
+ Returns:
+ np.array: 3x3 rotation matrix
+
+ Raises:
+ AssertionError: [Invalid input shape]
+ """
+
+ euler = np.asarray(euler, dtype=np.float64)
+ assert euler.shape[-1] == 3, "Invalid shaped euler {}".format(euler)
+
+ ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0]
+ si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak)
+ ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak)
+ cc, cs = ci * ck, ci * sk
+ sc, ss = si * ck, si * sk
+
+ mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64)
+ mat[..., 2, 2] = cj * ck
+ mat[..., 2, 1] = sj * sc - cs
+ mat[..., 2, 0] = sj * cc + ss
+ mat[..., 1, 2] = cj * sk
+ mat[..., 1, 1] = sj * ss + cc
+ mat[..., 1, 0] = sj * cs - sc
+ mat[..., 0, 2] = -sj
+ mat[..., 0, 1] = cj * si
+ mat[..., 0, 0] = cj * ci
+ return mat
+
+
+def mat2euler(rmat, axes="sxyz"):
+ """
+ Converts given rotation matrix to euler angles in radian.
+
+ Args:
+ rmat (np.array): 3x3 rotation matrix
+ axes (str): One of 24 axis sequences as string or encoded tuple (see top of this module)
+
+ Returns:
+ np.array: (r,p,y) converted euler angles in radian vec3 float
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i + parity]
+ k = _NEXT_AXIS[i - parity + 1]
+
+ M = np.array(rmat, dtype=np.float32, copy=False)[:3, :3]
+ if repetition:
+ sy = math.sqrt(M[i, j] * M[i, j] + M[i, k] * M[i, k])
+ if sy > EPS:
+ ax = math.atan2(M[i, j], M[i, k])
+ ay = math.atan2(sy, M[i, i])
+ az = math.atan2(M[j, i], -M[k, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2(sy, M[i, i])
+ az = 0.0
+ else:
+ cy = math.sqrt(M[i, i] * M[i, i] + M[j, i] * M[j, i])
+ if cy > EPS:
+ ax = math.atan2(M[k, j], M[k, k])
+ ay = math.atan2(-M[k, i], cy)
+ az = math.atan2(M[j, i], M[i, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2(-M[k, i], cy)
+ az = 0.0
+
+ if parity:
+ ax, ay, az = -ax, -ay, -az
+ if frame:
+ ax, az = az, ax
+ return vec((ax, ay, az))
+
+
+def pose2mat(pose):
+ """
+ Converts pose to homogeneous matrix.
+
+ Args:
+ pose (2-tuple): a (pos, orn) tuple where pos is vec3 float cartesian, and
+ orn is vec4 float quaternion.
+
+ Returns:
+ np.array: 4x4 homogeneous matrix
+ """
+ homo_pose_mat = np.zeros((4, 4), dtype=np.float32)
+ homo_pose_mat[:3, :3] = quat2mat(pose[1])
+ homo_pose_mat[:3, 3] = np.array(pose[0], dtype=np.float32)
+ homo_pose_mat[3, 3] = 1.0
+ return homo_pose_mat
+
+
+@jit_decorator
+def quat2mat(quaternion):
+ """
+ Converts given quaternion to matrix.
+
+ Args:
+ quaternion (np.array): (x,y,z,w) vec4 float angles
+
+ Returns:
+ np.array: 3x3 rotation matrix
+ """
+ # awkward semantics for use with numba
+ inds = np.array([3, 0, 1, 2])
+ q = np.asarray(quaternion).copy().astype(np.float32)[inds]
+
+ n = np.dot(q, q)
+ if n < EPS:
+ return np.identity(3)
+ q *= math.sqrt(2.0 / n)
+ q2 = np.outer(q, q)
+ return np.array(
+ [
+ [1.0 - q2[2, 2] - q2[3, 3], q2[1, 2] - q2[3, 0], q2[1, 3] + q2[2, 0]],
+ [q2[1, 2] + q2[3, 0], 1.0 - q2[1, 1] - q2[3, 3], q2[2, 3] - q2[1, 0]],
+ [q2[1, 3] - q2[2, 0], q2[2, 3] + q2[1, 0], 1.0 - q2[1, 1] - q2[2, 2]],
+ ]
+ )
+
+
+def quat2axisangle(quat):
+ """
+ Converts quaternion to axis-angle format.
+ Returns a unit vector direction scaled by its angle in radians.
+
+ Args:
+ quat (np.array): (x,y,z,w) vec4 float angles
+
+ Returns:
+ np.array: (ax,ay,az) axis-angle exponential coordinates
+ """
+ # clip quaternion
+ if quat[3] > 1.0:
+ quat[3] = 1.0
+ elif quat[3] < -1.0:
+ quat[3] = -1.0
+
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(den, 0.0):
+ # This is (close to) a zero degree rotation, immediately return
+ return np.zeros(3)
+
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+def axisangle2quat(vec):
+ """
+ Converts scaled axis-angle to quat.
+
+ Args:
+ vec (np.array): (ax,ay,az) axis-angle exponential coordinates
+
+ Returns:
+ np.array: (x,y,z,w) vec4 float angles
+ """
+ # Grab angle
+ angle = np.linalg.norm(vec)
+
+ # handle zero-rotation case
+ if math.isclose(angle, 0.0):
+ return np.array([0.0, 0.0, 0.0, 1.0])
+
+ # make sure that axis is a unit vector
+ axis = vec / angle
+
+ q = np.zeros(4)
+ q[3] = np.cos(angle / 2.0)
+ q[:3] = axis * np.sin(angle / 2.0)
+ return q
+
+
+def pose_in_A_to_pose_in_B(pose_A, pose_A_in_B):
+ """
+ Converts a homogenous matrix corresponding to a point C in frame A
+ to a homogenous matrix corresponding to the same point C in frame B.
+
+ Args:
+ pose_A (np.array): 4x4 matrix corresponding to the pose of C in frame A
+ pose_A_in_B (np.array): 4x4 matrix corresponding to the pose of A in frame B
+
+ Returns:
+ np.array: 4x4 matrix corresponding to the pose of C in frame B
+ """
+
+ # pose of A in B takes a point in A and transforms it to a point in C.
+
+ # pose of C in B = pose of A in B * pose of C in A
+ # take a point in C, transform it to A, then to B
+ # T_B^C = T_A^C * T_B^A
+ return pose_A_in_B.dot(pose_A)
+
+
+def pose_inv(pose):
+ """
+ Computes the inverse of a homogeneous matrix corresponding to the pose of some
+ frame B in frame A. The inverse is the pose of frame A in frame B.
+
+ Args:
+ pose (np.array): 4x4 matrix for the pose to inverse
+
+ Returns:
+ np.array: 4x4 matrix for the inverse pose
+ """
+
+ # Note, the inverse of a pose matrix is the following
+ # [R t; 0 1]^-1 = [R.T -R.T*t; 0 1]
+
+ # Intuitively, this makes sense.
+ # The original pose matrix translates by t, then rotates by R.
+ # We just invert the rotation by applying R-1 = R.T, and also translate back.
+ # Since we apply translation first before rotation, we need to translate by
+ # -t in the original frame, which is -R-1*t in the new frame, and then rotate back by
+ # R-1 to align the axis again.
+
+ pose_inv = np.zeros((4, 4))
+ pose_inv[:3, :3] = pose[:3, :3].T
+ pose_inv[:3, 3] = -pose_inv[:3, :3].dot(pose[:3, 3])
+ pose_inv[3, 3] = 1.0
+ return pose_inv
+
+
+def _skew_symmetric_translation(pos_A_in_B):
+ """
+ Helper function to get a skew symmetric translation matrix for converting quantities
+ between frames.
+
+ Args:
+ pos_A_in_B (np.array): (x,y,z) position of A in frame B
+
+ Returns:
+ np.array: 3x3 skew symmetric translation matrix
+ """
+ return np.array(
+ [
+ 0.0,
+ -pos_A_in_B[2],
+ pos_A_in_B[1],
+ pos_A_in_B[2],
+ 0.0,
+ -pos_A_in_B[0],
+ -pos_A_in_B[1],
+ pos_A_in_B[0],
+ 0.0,
+ ]
+ ).reshape((3, 3))
+
+
+def vel_in_A_to_vel_in_B(vel_A, ang_vel_A, pose_A_in_B):
+ """
+ Converts linear and angular velocity of a point in frame A to the equivalent in frame B.
+
+ Args:
+ vel_A (np.array): (vx,vy,vz) linear velocity in A
+ ang_vel_A (np.array): (wx,wy,wz) angular velocity in A
+ pose_A_in_B (np.array): 4x4 matrix corresponding to the pose of A in frame B
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) (vx,vy,vz) linear velocities in frame B
+ - (np.array) (wx,wy,wz) angular velocities in frame B
+ """
+ pos_A_in_B = pose_A_in_B[:3, 3]
+ rot_A_in_B = pose_A_in_B[:3, :3]
+ skew_symm = _skew_symmetric_translation(pos_A_in_B)
+ vel_B = rot_A_in_B.dot(vel_A) + skew_symm.dot(rot_A_in_B.dot(ang_vel_A))
+ ang_vel_B = rot_A_in_B.dot(ang_vel_A)
+ return vel_B, ang_vel_B
+
+
+def force_in_A_to_force_in_B(force_A, torque_A, pose_A_in_B):
+ """
+ Converts linear and rotational force at a point in frame A to the equivalent in frame B.
+
+ Args:
+ force_A (np.array): (fx,fy,fz) linear force in A
+ torque_A (np.array): (tx,ty,tz) rotational force (moment) in A
+ pose_A_in_B (np.array): 4x4 matrix corresponding to the pose of A in frame B
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) (fx,fy,fz) linear forces in frame B
+ - (np.array) (tx,ty,tz) moments in frame B
+ """
+ pos_A_in_B = pose_A_in_B[:3, 3]
+ rot_A_in_B = pose_A_in_B[:3, :3]
+ skew_symm = _skew_symmetric_translation(pos_A_in_B)
+ force_B = rot_A_in_B.T.dot(force_A)
+ torque_B = -rot_A_in_B.T.dot(skew_symm.dot(force_A)) + rot_A_in_B.T.dot(torque_A)
+ return force_B, torque_B
+
+
+def rotation_matrix(angle, direction, point=None):
+ """
+ Returns matrix to rotate about axis defined by point and direction.
+
+ E.g.:
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(-angle, -direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+
+ >>> I = numpy.identity(4, numpy.float32)
+ >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
+ True
+
+ >>> numpy.allclose(2., numpy.trace(rotation_matrix(math.pi/2,
+ ... direc, point)))
+ True
+
+ Args:
+ angle (float): Magnitude of rotation
+ direction (np.array): (ax,ay,az) axis about which to rotate
+ point (None or np.array): If specified, is the (x,y,z) point about which the rotation will occur
+
+ Returns:
+ np.array: 4x4 homogeneous matrix that includes the desired rotation
+ """
+ sina = math.sin(angle)
+ cosa = math.cos(angle)
+ direction = unit_vector(direction[:3])
+ # rotation matrix around unit vector
+ R = np.array(((cosa, 0.0, 0.0), (0.0, cosa, 0.0), (0.0, 0.0, cosa)), dtype=np.float32)
+ R += np.outer(direction, direction) * (1.0 - cosa)
+ direction *= sina
+ R += np.array(
+ (
+ (0.0, -direction[2], direction[1]),
+ (direction[2], 0.0, -direction[0]),
+ (-direction[1], direction[0], 0.0),
+ ),
+ dtype=np.float32,
+ )
+ M = np.identity(4)
+ M[:3, :3] = R
+ if point is not None:
+ # rotation not around origin
+ point = np.array(point[:3], dtype=np.float32, copy=False)
+ M[:3, 3] = point - np.dot(R, point)
+ return M
+
+
+def clip_translation(dpos, limit):
+ """
+ Limits a translation (delta position) to a specified limit
+
+ Scales down the norm of the dpos to 'limit' if norm(dpos) > limit, else returns immediately
+
+ Args:
+ dpos (n-array): n-dim Translation being clipped (e,g.: (x, y, z)) -- numpy array
+ limit (float): Value to limit translation by -- magnitude (scalar, in same units as input)
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) Clipped translation (same dimension as inputs)
+ - (bool) whether the value was clipped or not
+ """
+ input_norm = np.linalg.norm(dpos)
+ return (dpos * limit / input_norm, True) if input_norm > limit else (dpos, False)
+
+
+def clip_rotation(quat, limit):
+ """
+ Limits a (delta) rotation to a specified limit
+
+ Converts rotation to axis-angle, clips, then re-converts back into quaternion
+
+ Args:
+ quat (np.array): (x,y,z,w) rotation being clipped
+ limit (float): Value to limit rotation by -- magnitude (scalar, in radians)
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) Clipped rotation quaternion (x, y, z, w)
+ - (bool) whether the value was clipped or not
+ """
+ clipped = False
+
+ # First, normalize the quaternion
+ quat = quat / np.linalg.norm(quat)
+
+ den = np.sqrt(max(1 - quat[3] * quat[3], 0))
+ if den == 0:
+ # This is a zero degree rotation, immediately return
+ return quat, clipped
+ else:
+ # This is all other cases
+ x = quat[0] / den
+ y = quat[1] / den
+ z = quat[2] / den
+ a = 2 * math.acos(quat[3])
+
+ # Clip rotation if necessary and return clipped quat
+ if abs(a) > limit:
+ a = limit * np.sign(a) / 2
+ sa = math.sin(a)
+ ca = math.cos(a)
+ quat = np.array([x * sa, y * sa, z * sa, ca])
+ clipped = True
+
+ return quat, clipped
+
+
+def make_pose(translation, rotation):
+ """
+ Makes a homogeneous pose matrix from a translation vector and a rotation matrix.
+
+ Args:
+ translation (np.array): (x,y,z) translation value
+ rotation (np.array): a 3x3 matrix representing rotation
+
+ Returns:
+ pose (np.array): a 4x4 homogeneous matrix
+ """
+ pose = np.zeros((4, 4))
+ pose[:3, :3] = rotation
+ pose[:3, 3] = translation
+ pose[3, 3] = 1.0
+ return pose
+
+
+def unit_vector(data, axis=None, out=None):
+ """
+ Returns ndarray normalized by length, i.e. eucledian norm, along axis.
+
+ E.g.:
+ >>> v0 = numpy.random.random(3)
+ >>> v1 = unit_vector(v0)
+ >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
+ True
+
+ >>> v0 = numpy.random.rand(5, 4, 3)
+ >>> v1 = unit_vector(v0, axis=-1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
+ >>> numpy.allclose(v1, v2)
+ True
+
+ >>> v1 = unit_vector(v0, axis=1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
+ >>> numpy.allclose(v1, v2)
+ True
+
+ >>> v1 = numpy.empty((5, 4, 3), dtype=numpy.float32)
+ >>> unit_vector(v0, axis=1, out=v1)
+ >>> numpy.allclose(v1, v2)
+ True
+
+ >>> list(unit_vector([]))
+ []
+
+ >>> list(unit_vector([1.0]))
+ [1.0]
+
+ Args:
+ data (np.array): data to normalize
+ axis (None or int): If specified, determines specific axis along data to normalize
+ out (None or np.array): If specified, will store computation in this variable
+
+ Returns:
+ None or np.array: If @out is not specified, will return normalized vector. Otherwise, stores the output in @out
+ """
+ if out is None:
+ data = np.array(data, dtype=np.float32, copy=True)
+ if data.ndim == 1:
+ data /= math.sqrt(np.dot(data, data))
+ return data
+ else:
+ if out is not data:
+ out[:] = np.array(data, copy=False)
+ data = out
+ length = np.atleast_1d(np.sum(data * data, axis))
+ np.sqrt(length, length)
+ if axis is not None:
+ length = np.expand_dims(length, axis)
+ data /= length
+ if out is None:
+ return data
+
+
+def get_orientation_error(target_orn, current_orn):
+ """
+ Returns the difference between two quaternion orientations as a 3 DOF numpy array.
+ For use in an impedance controller / task-space PD controller.
+
+ Args:
+ target_orn (np.array): (x, y, z, w) desired quaternion orientation
+ current_orn (np.array): (x, y, z, w) current quaternion orientation
+
+ Returns:
+ orn_error (np.array): (ax,ay,az) current orientation error, corresponds to
+ (target_orn - current_orn)
+ """
+ current_orn = np.array([current_orn[3], current_orn[0], current_orn[1], current_orn[2]])
+ target_orn = np.array([target_orn[3], target_orn[0], target_orn[1], target_orn[2]])
+
+ pinv = np.zeros((3, 4))
+ pinv[0, :] = [-current_orn[1], current_orn[0], -current_orn[3], current_orn[2]]
+ pinv[1, :] = [-current_orn[2], current_orn[3], current_orn[0], -current_orn[1]]
+ pinv[2, :] = [-current_orn[3], -current_orn[2], current_orn[1], current_orn[0]]
+ orn_error = 2.0 * pinv.dot(np.array(target_orn))
+ return orn_error
+
+
+def get_pose_error(target_pose, current_pose):
+ """
+ Computes the error corresponding to target pose - current pose as a 6-dim vector.
+ The first 3 components correspond to translational error while the last 3 components
+ correspond to the rotational error.
+
+ Args:
+ target_pose (np.array): a 4x4 homogenous matrix for the target pose
+ current_pose (np.array): a 4x4 homogenous matrix for the current pose
+
+ Returns:
+ np.array: 6-dim pose error.
+ """
+ error = np.zeros(6)
+
+ # compute translational error
+ target_pos = target_pose[:3, 3]
+ current_pos = current_pose[:3, 3]
+ pos_err = target_pos - current_pos
+
+ # compute rotational error
+ r1 = current_pose[:3, 0]
+ r2 = current_pose[:3, 1]
+ r3 = current_pose[:3, 2]
+ r1d = target_pose[:3, 0]
+ r2d = target_pose[:3, 1]
+ r3d = target_pose[:3, 2]
+ rot_err = 0.5 * (np.cross(r1, r1d) + np.cross(r2, r2d) + np.cross(r3, r3d))
+
+ error[:3] = pos_err
+ error[3:] = rot_err
+ return error
+
+
+@jit_decorator
+def matrix_inverse(matrix):
+ """
+ Helper function to have an efficient matrix inversion function.
+
+ Args:
+ matrix (np.array): 2d-array representing a matrix
+
+ Returns:
+ np.array: 2d-array representing the matrix inverse
+ """
+ return np.linalg.inv(matrix)
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/__init__.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..282a66a0702dda49c0a4d2822bb5c994d85d8cd4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/__init__.py
@@ -0,0 +1,10 @@
+from robosuite.wrappers.wrapper import Wrapper
+from robosuite.wrappers.data_collection_wrapper import DataCollectionWrapper
+from robosuite.wrappers.demo_sampler_wrapper import DemoSamplerWrapper
+from robosuite.wrappers.domain_randomization_wrapper import DomainRandomizationWrapper
+from robosuite.wrappers.visualization_wrapper import VisualizationWrapper
+
+try:
+ from robosuite.wrappers.gym_wrapper import GymWrapper
+except:
+ print("Warning: make sure gym is installed if you want to use the GymWrapper.")
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/data_collection_wrapper.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/data_collection_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..60602aa6d005b677d2416e0a64638a0b4109f7cd
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/data_collection_wrapper.py
@@ -0,0 +1,188 @@
+"""
+This file implements a wrapper for saving simulation states to disk.
+This data collection wrapper is useful for collecting demonstrations.
+"""
+
+import os
+import time
+
+import numpy as np
+
+from robosuite.utils.mjcf_utils import save_sim_model
+from robosuite.wrappers import Wrapper
+
+
+class DataCollectionWrapper(Wrapper):
+ def __init__(self, env, directory, collect_freq=1, flush_freq=100):
+ """
+ Initializes the data collection wrapper.
+
+ Args:
+ env (MujocoEnv): The environment to monitor.
+ directory (str): Where to store collected data.
+ collect_freq (int): How often to save simulation state, in terms of environment steps.
+ flush_freq (int): How frequently to dump data to disk, in terms of environment steps.
+ """
+ super().__init__(env)
+
+ # the base directory for all logging
+ self.directory = directory
+
+ # in-memory cache for simulation states and action info
+ self.states = []
+ self.action_infos = [] # stores information about actions taken
+ self.successful = False # stores success state of demonstration
+
+ # how often to save simulation state, in terms of environment steps
+ self.collect_freq = collect_freq
+
+ # how frequently to dump data to disk, in terms of environment steps
+ self.flush_freq = flush_freq
+
+ if not os.path.exists(directory):
+ print("DataCollectionWrapper: making new directory at {}".format(directory))
+ os.makedirs(directory)
+
+ # store logging directory for current episode
+ self.ep_directory = None
+
+ # remember whether any environment interaction has occurred
+ self.has_interaction = False
+
+ # some variables for remembering the current episode's initial state and model xml
+ self._current_task_instance_state = None
+ self._current_task_instance_xml = None
+
+ def _start_new_episode(self):
+ """
+ Bookkeeping to do at the start of each new episode.
+ """
+
+ # flush any data left over from the previous episode if any interactions have happened
+ if self.has_interaction:
+ self._flush()
+
+ # timesteps in current episode
+ self.t = 0
+ self.has_interaction = False
+
+ # save the task instance (will be saved on the first env interaction)
+ self._current_task_instance_xml = self.env.sim.model.get_xml()
+ self._current_task_instance_state = np.array(self.env.sim.get_state().flatten())
+
+ # trick for ensuring that we can play MuJoCo demonstrations back
+ # deterministically by using the recorded actions open loop
+ self.env.reset_from_xml_string(self._current_task_instance_xml)
+ self.env.sim.reset()
+ self.env.sim.set_state_from_flattened(self._current_task_instance_state)
+ self.env.sim.forward()
+
+ def _on_first_interaction(self):
+ """
+ Bookkeeping for first timestep of episode.
+ This function is necessary to make sure that logging only happens after the first
+ step call to the simulation, instead of on the reset (people tend to call
+ reset more than is necessary in code).
+
+ Raises:
+ AssertionError: [Episode path already exists]
+ """
+
+ self.has_interaction = True
+
+ # create a directory with a timestamp
+ t1, t2 = str(time.time()).split(".")
+ self.ep_directory = os.path.join(self.directory, "ep_{}_{}".format(t1, t2))
+ assert not os.path.exists(self.ep_directory)
+ print("DataCollectionWrapper: making folder at {}".format(self.ep_directory))
+ os.makedirs(self.ep_directory)
+
+ # save the model xml
+ xml_path = os.path.join(self.ep_directory, "model.xml")
+ with open(xml_path, "w") as f:
+ f.write(self._current_task_instance_xml)
+
+ # save initial state and action
+ assert len(self.states) == 0
+ self.states.append(self._current_task_instance_state)
+
+ def _flush(self):
+ """
+ Method to flush internal state to disk.
+ """
+ t1, t2 = str(time.time()).split(".")
+ state_path = os.path.join(self.ep_directory, "state_{}_{}.npz".format(t1, t2))
+ if hasattr(self.env, "unwrapped"):
+ env_name = self.env.unwrapped.__class__.__name__
+ else:
+ env_name = self.env.__class__.__name__
+ np.savez(
+ state_path,
+ states=np.array(self.states),
+ action_infos=self.action_infos,
+ successful=self.successful,
+ env=env_name,
+ )
+ self.states = []
+ self.action_infos = []
+ self.successful = False
+
+ def reset(self):
+ """
+ Extends vanilla reset() function call to accommodate data collection
+
+ Returns:
+ OrderedDict: Environment observation space after reset occurs
+ """
+ ret = super().reset()
+ self._start_new_episode()
+ return ret
+
+ def step(self, action):
+ """
+ Extends vanilla step() function call to accommodate data collection
+
+ Args:
+ action (np.array): Action to take in environment
+
+ Returns:
+ 4-tuple:
+
+ - (OrderedDict) observations from the environment
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) misc information
+ """
+ ret = super().step(action)
+ self.t += 1
+
+ # on the first time step, make directories for logging
+ if not self.has_interaction:
+ self._on_first_interaction()
+
+ # collect the current simulation state if necessary
+ if self.t % self.collect_freq == 0:
+ state = self.env.sim.get_state().flatten()
+ self.states.append(state)
+
+ info = {}
+ info["actions"] = np.array(action)
+ self.action_infos.append(info)
+
+ # check if the demonstration is successful
+ if self.env._check_success():
+ self.successful = True
+
+ # flush collected data to disk if necessary
+ if self.t % self.flush_freq == 0:
+ self._flush()
+
+ return ret
+
+ def close(self):
+ """
+ Override close method in order to flush left over data
+ """
+ if self.has_interaction:
+ self._flush()
+ self.env.close()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/demo_sampler_wrapper.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/demo_sampler_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..227045fa52c0d3b01ca563886ca56f6eae1593f5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/demo_sampler_wrapper.py
@@ -0,0 +1,316 @@
+"""
+This file contains a wrapper for sampling environment states
+from a set of demonstrations on every reset. The main use case is for
+altering the start state distribution of training episodes for
+learning RL policies.
+"""
+
+import os
+import random
+import time
+
+import h5py
+import numpy as np
+
+from robosuite.wrappers import Wrapper
+
+
+class DemoSamplerWrapper(Wrapper):
+ """
+ Initializes a wrapper that provides support for resetting the environment
+ state to one from a demonstration. It also supports curriculums for
+ altering how often to sample from demonstration vs. sampling a reset
+ state from the environment.
+
+ Args:
+ env (MujocoEnv): The environment to wrap.
+
+ demo_path (str): The path to the folder containing the demonstrations.
+ There should be a `demo.hdf5` file and a folder named `models` with
+ all of the stored model xml files from the demonstrations.
+
+ need_xml (bool): If True, the mujoco model needs to be reloaded when
+ sampling a state from a demonstration. This could be because every
+ demonstration was taken under varied object properties, for example.
+ In this case, every sampled state comes with a corresponding xml to
+ be used for the environment reset.
+
+ num_traj (int): If provided, subsample @number demonstrations from the
+ provided set of demonstrations instead of using all of them.
+
+ sampling_schemes (list of str): A list of sampling schemes
+ to be used. The following strings are valid schemes:
+
+ `'random'`: sample a reset state directly from the wrapped environment
+
+ `'uniform'`: sample a state from a demonstration uniformly at random
+
+ `'forward'`: sample a state from a window that grows progressively from
+ the start of demonstrations
+
+ `'reverse'`: sample a state from a window that grows progressively from
+ the end of demonstrations
+
+ scheme_ratios (list of float --> np.array): A list of probability values to
+ assign to each member of @sampling_schemes. Must be non-negative and
+ sum to 1.
+
+ open_loop_increment_freq (int): How frequently to increase
+ the window size in open loop schemes ("forward" and "reverse"). The
+ window size will increase by @open_loop_window_increment every
+ @open_loop_increment_freq samples. Only samples that are generated
+ by open loop schemes contribute to this count.
+
+ open_loop_initial_window_width (int): The width of the initial sampling
+ window, in terms of number of demonstration time steps, for
+ open loop schemes.
+
+ open_loop_window_increment (int): The window size will increase by
+ @open_loop_window_increment every @open_loop_increment_freq samples.
+ This number is in terms of number of demonstration time steps.
+
+ Raises:
+ AssertionError: [Incompatible envs]
+ AssertionError: [Invalid sampling scheme]
+ AssertionError: [Invalid scheme ratio]
+ """
+
+ def __init__(
+ self,
+ env,
+ demo_path,
+ need_xml=False,
+ num_traj=-1,
+ sampling_schemes=("uniform", "random"),
+ scheme_ratios=(0.9, 0.1),
+ open_loop_increment_freq=100,
+ open_loop_initial_window_width=25,
+ open_loop_window_increment=25,
+ ):
+ super().__init__(env)
+
+ self.demo_path = demo_path
+ hdf5_path = os.path.join(self.demo_path, "demo.hdf5")
+ self.demo_file = h5py.File(hdf5_path, "r")
+
+ # ensure that wrapped env matches the env on which demonstrations were collected
+ env_name = self.demo_file["data"].attrs["env"]
+ assert (
+ env_name == self.unwrapped.__class__.__name__
+ ), "Wrapped env {} does not match env on which demos were collected ({})".format(
+ env.__class__.__name__, env_name
+ )
+
+ # list of all demonstrations episodes
+ self.demo_list = list(self.demo_file["data"].keys())
+
+ # subsample a selection of demonstrations if requested
+ if num_traj > 0:
+ random.seed(3141) # ensure that the same set is sampled every time
+ self.demo_list = random.sample(self.demo_list, num_traj)
+
+ self.need_xml = need_xml
+ self.demo_sampled = 0
+
+ self.sample_method_dict = {
+ "random": "_random_sample",
+ "uniform": "_uniform_sample",
+ "forward": "_forward_sample_open_loop",
+ "reverse": "_reverse_sample_open_loop",
+ }
+
+ self.sampling_schemes = sampling_schemes
+ self.scheme_ratios = np.asarray(scheme_ratios)
+
+ # make sure the list of schemes is valid
+ schemes = self.sample_method_dict.keys()
+ assert np.all([(s in schemes) for s in self.sampling_schemes])
+
+ # make sure the distribution is the correct size
+ assert len(self.sampling_schemes) == len(self.scheme_ratios)
+
+ # make sure the distribution lies in the probability simplex
+ assert np.all(self.scheme_ratios > 0.0)
+ assert sum(self.scheme_ratios) == 1.0
+
+ # open loop configuration
+ self.open_loop_increment_freq = open_loop_increment_freq
+ self.open_loop_window_increment = open_loop_window_increment
+
+ # keep track of window size
+ self.open_loop_window_size = open_loop_initial_window_width
+
+ def reset(self):
+ """
+ Logic for sampling a state from the demonstration and resetting
+ the simulation to that state.
+
+ Returns:
+ OrderedDict: Environment observation space after reset occurs
+ """
+ state = self.sample()
+ if state is None:
+ # None indicates that a normal env reset should occur
+ return self.env.reset()
+ else:
+ if self.need_xml:
+ # reset the simulation from the model if necessary
+ state, xml = state
+ self.env.reset_from_xml_string(xml)
+
+ if isinstance(state, tuple):
+ state = state[0]
+
+ # force simulator state to one from the demo
+ self.sim.set_state_from_flattened(state)
+ self.sim.forward()
+
+ return self.env._get_observation()
+
+ def sample(self):
+ """
+ This is the core sampling method. Samples a state from a
+ demonstration, in accordance with the configuration.
+
+ Returns:
+ None or np.array or 2-tuple: If np.array, is the state sampled from a demo file. If 2-tuple, additionally
+ includes the model xml file
+ """
+
+ # chooses a sampling scheme randomly based on the mixing ratios
+ seed = random.uniform(0, 1)
+ ratio = np.cumsum(self.scheme_ratios)
+ ratio = ratio > seed
+ for i, v in enumerate(ratio):
+ if v:
+ break
+
+ sample_method = getattr(self, self.sample_method_dict[self.sampling_schemes[i]])
+ return sample_method()
+
+ def _random_sample(self):
+ """
+ Sampling method.
+
+ Return None to indicate that the state should be sampled directly
+ from the environment.
+ """
+ return None
+
+ def _uniform_sample(self):
+ """
+ Sampling method.
+
+ First uniformly sample a demonstration from the set of demonstrations.
+ Then uniformly sample a state from the selected demonstration.
+
+ Returns:
+ np.array or 2-tuple: If np.array, is the state sampled from a demo file. If 2-tuple, additionally
+ includes the model xml file
+ """
+
+ # get a random episode index
+ ep_ind = random.choice(self.demo_list)
+
+ # select a flattened mujoco state uniformly from this episode
+ states = self.demo_file["data/{}/states".format(ep_ind)][()]
+ state = random.choice(states)
+
+ if self.need_xml:
+ model_xml = self._xml_for_episode_index(ep_ind)
+ xml = self.env.edit_model_xml(model_xml)
+ return state, xml
+ return state
+
+ def _reverse_sample_open_loop(self):
+ """
+ Sampling method.
+
+ Open loop reverse sampling from demonstrations. Starts by
+ sampling from states near the end of the demonstrations.
+ Increases the window backwards as the number of calls to
+ this sampling method increases at a fixed rate.
+
+ Returns:
+ np.array or 2-tuple: If np.array, is the state sampled from a demo file. If 2-tuple, additionally
+ includes the model xml file
+ """
+
+ # get a random episode index
+ ep_ind = random.choice(self.demo_list)
+
+ # sample uniformly in a window that grows backwards from the end of the demos
+ states = self.demo_file["data/{}/states".format(ep_ind)][()]
+ eps_len = states.shape[0]
+ index = np.random.randint(max(eps_len - self.open_loop_window_size, 0), eps_len)
+ state = states[index]
+
+ # increase window size at a fixed frequency (open loop)
+ self.demo_sampled += 1
+ if self.demo_sampled >= self.open_loop_increment_freq:
+ if self.open_loop_window_size < eps_len:
+ self.open_loop_window_size += self.open_loop_window_increment
+ self.demo_sampled = 0
+
+ if self.need_xml:
+ model_xml = self._xml_for_episode_index(ep_ind)
+ xml = self.env.edit_model_xml(model_xml)
+ return state, xml
+
+ return state
+
+ def _forward_sample_open_loop(self):
+ """
+ Sampling method.
+
+ Open loop forward sampling from demonstrations. Starts by
+ sampling from states near the beginning of the demonstrations.
+ Increases the window forwards as the number of calls to
+ this sampling method increases at a fixed rate.
+
+ Returns:
+ np.array or 2-tuple: If np.array, is the state sampled from a demo file. If 2-tuple, additionally
+ includes the model xml file
+ """
+
+ # get a random episode index
+ ep_ind = random.choice(self.demo_list)
+
+ # sample uniformly in a window that grows forwards from the beginning of the demos
+ states = self.demo_file["data/{}/states".format(ep_ind)][()]
+ eps_len = states.shape[0]
+ index = np.random.randint(0, min(self.open_loop_window_size, eps_len))
+ state = states[index]
+
+ # increase window size at a fixed frequency (open loop)
+ self.demo_sampled += 1
+ if self.demo_sampled >= self.open_loop_increment_freq:
+ if self.open_loop_window_size < eps_len:
+ self.open_loop_window_size += self.open_loop_window_increment
+ self.demo_sampled = 0
+
+ if self.need_xml:
+ model_xml = self._xml_for_episode_index(ep_ind)
+ xml = self.env.edit_model_xml(model_xml)
+ return state, xml
+
+ return state
+
+ def _xml_for_episode_index(self, ep_ind):
+ """
+ Helper method to retrieve the corresponding model xml string
+ for the passed episode index.
+
+ Args:
+ ep_ind (int): Episode index to pull from demo file
+
+ Returns:
+ str: model xml as a string
+ """
+
+ # read the model xml, using the metadata stored in the attribute for this episode
+ model_file = self.demo_file["data/{}".format(ep_ind)].attrs["model_file"]
+ model_path = os.path.join(self.demo_path, "models", model_file)
+ with open(model_path, "r") as model_f:
+ model_xml = model_f.read()
+ return model_xml
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/domain_randomization_wrapper.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/domain_randomization_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..70dcd7cb9ac9e77f38b2036d581015d84afb8a32
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/domain_randomization_wrapper.py
@@ -0,0 +1,266 @@
+"""
+This file implements a wrapper for facilitating domain randomization over
+robosuite environments.
+"""
+import numpy as np
+
+from robosuite.utils.mjmod import CameraModder, DynamicsModder, LightingModder, TextureModder
+from robosuite.wrappers import Wrapper
+
+DEFAULT_COLOR_ARGS = {
+ "geom_names": None, # all geoms are randomized
+ "randomize_local": True, # sample nearby colors
+ "randomize_material": True, # randomize material reflectance / shininess / specular
+ "local_rgb_interpolation": 0.2,
+ "local_material_interpolation": 0.3,
+ "texture_variations": ["rgb", "checker", "noise", "gradient"], # all texture variation types
+ "randomize_skybox": True, # by default, randomize skybox too
+}
+
+DEFAULT_CAMERA_ARGS = {
+ "camera_names": None, # all cameras are randomized
+ "randomize_position": True,
+ "randomize_rotation": True,
+ "randomize_fovy": True,
+ "position_perturbation_size": 0.01,
+ "rotation_perturbation_size": 0.087,
+ "fovy_perturbation_size": 5.0,
+}
+
+DEFAULT_LIGHTING_ARGS = {
+ "light_names": None, # all lights are randomized
+ "randomize_position": True,
+ "randomize_direction": True,
+ "randomize_specular": True,
+ "randomize_ambient": True,
+ "randomize_diffuse": True,
+ "randomize_active": True,
+ "position_perturbation_size": 0.1,
+ "direction_perturbation_size": 0.35,
+ "specular_perturbation_size": 0.1,
+ "ambient_perturbation_size": 0.1,
+ "diffuse_perturbation_size": 0.1,
+}
+
+DEFAULT_DYNAMICS_ARGS = {
+ # Opt parameters
+ "randomize_density": True,
+ "randomize_viscosity": True,
+ "density_perturbation_ratio": 0.1,
+ "viscosity_perturbation_ratio": 0.1,
+ # Body parameters
+ "body_names": None, # all bodies randomized
+ "randomize_position": True,
+ "randomize_quaternion": True,
+ "randomize_inertia": True,
+ "randomize_mass": True,
+ "position_perturbation_size": 0.0015,
+ "quaternion_perturbation_size": 0.003,
+ "inertia_perturbation_ratio": 0.02,
+ "mass_perturbation_ratio": 0.02,
+ # Geom parameters
+ "geom_names": None, # all geoms randomized
+ "randomize_friction": True,
+ "randomize_solref": True,
+ "randomize_solimp": True,
+ "friction_perturbation_ratio": 0.1,
+ "solref_perturbation_ratio": 0.1,
+ "solimp_perturbation_ratio": 0.1,
+ # Joint parameters
+ "joint_names": None, # all joints randomized
+ "randomize_stiffness": True,
+ "randomize_frictionloss": True,
+ "randomize_damping": True,
+ "randomize_armature": True,
+ "stiffness_perturbation_ratio": 0.1,
+ "frictionloss_perturbation_size": 0.05,
+ "damping_perturbation_size": 0.01,
+ "armature_perturbation_size": 0.01,
+}
+
+
+class DomainRandomizationWrapper(Wrapper):
+ """
+ Wrapper that allows for domain randomization mid-simulation.
+
+ Args:
+ env (MujocoEnv): The environment to wrap.
+
+ seed (int): Integer used to seed all randomizations from this wrapper. It is
+ used to create a np.random.RandomState instance to make sure samples here
+ are isolated from sampling occurring elsewhere in the code. If not provided,
+ will default to using global random state.
+
+ randomize_color (bool): if True, randomize geom colors and texture colors
+
+ randomize_camera (bool): if True, randomize camera locations and parameters
+
+ randomize_lighting (bool): if True, randomize light locations and properties
+
+ randomize_dyanmics (bool): if True, randomize dynamics parameters
+
+ color_randomization_args (dict): Color-specific randomization arguments
+
+ camera_randomization_args (dict): Camera-specific randomization arguments
+
+ lighting_randomization_args (dict): Lighting-specific randomization arguments
+
+ dynamics_randomization_args (dict): Dyanmics-specific randomization arguments
+
+ randomize_on_reset (bool): if True, randomize on every call to @reset. This, in
+ conjunction with setting @randomize_every_n_steps to 0, is useful to
+ generate a new domain per episode.
+
+ randomize_every_n_steps (int): determines how often randomization should occur. Set
+ to 0 if randomization should happen manually (by calling @randomize_domain)
+
+ """
+
+ def __init__(
+ self,
+ env,
+ seed=None,
+ randomize_color=True,
+ randomize_camera=True,
+ randomize_lighting=True,
+ randomize_dynamics=True,
+ color_randomization_args=DEFAULT_COLOR_ARGS,
+ camera_randomization_args=DEFAULT_CAMERA_ARGS,
+ lighting_randomization_args=DEFAULT_LIGHTING_ARGS,
+ dynamics_randomization_args=DEFAULT_DYNAMICS_ARGS,
+ randomize_on_reset=True,
+ randomize_every_n_steps=1,
+ ):
+ super().__init__(env)
+
+ self.seed = seed
+ if seed is not None:
+ self.random_state = np.random.RandomState(seed)
+ else:
+ self.random_state = None
+ self.randomize_color = randomize_color
+ self.randomize_camera = randomize_camera
+ self.randomize_lighting = randomize_lighting
+ self.randomize_dynamics = randomize_dynamics
+ self.color_randomization_args = color_randomization_args
+ self.camera_randomization_args = camera_randomization_args
+ self.lighting_randomization_args = lighting_randomization_args
+ self.dynamics_randomization_args = dynamics_randomization_args
+ self.randomize_on_reset = randomize_on_reset
+ self.randomize_every_n_steps = randomize_every_n_steps
+
+ self.step_counter = 0
+
+ self.modders = []
+
+ if self.randomize_color:
+ self.tex_modder = TextureModder(
+ sim=self.env.sim, random_state=self.random_state, **self.color_randomization_args
+ )
+ self.modders.append(self.tex_modder)
+
+ if self.randomize_camera:
+ self.camera_modder = CameraModder(
+ sim=self.env.sim,
+ random_state=self.random_state,
+ **self.camera_randomization_args,
+ )
+ self.modders.append(self.camera_modder)
+
+ if self.randomize_lighting:
+ self.light_modder = LightingModder(
+ sim=self.env.sim,
+ random_state=self.random_state,
+ **self.lighting_randomization_args,
+ )
+ self.modders.append(self.light_modder)
+
+ if self.randomize_dynamics:
+ self.dynamics_modder = DynamicsModder(
+ sim=self.env.sim,
+ random_state=self.random_state,
+ **self.dynamics_randomization_args,
+ )
+ self.modders.append(self.dynamics_modder)
+
+ self.save_default_domain()
+
+ def reset(self):
+ """
+ Extends superclass method to reset the domain randomizer.
+
+ Returns:
+ OrderedDict: Environment observation space after reset occurs
+ """
+ # undo all randomizations
+ self.restore_default_domain()
+
+ # normal env reset
+ ret = super().reset()
+
+ # save the original env parameters
+ self.save_default_domain()
+
+ # reset counter for doing domain randomization at a particular frequency
+ self.step_counter = 0
+
+ # update sims
+ for modder in self.modders:
+ modder.update_sim(self.env.sim)
+
+ if self.randomize_on_reset:
+ # domain randomize + regenerate observation
+ self.randomize_domain()
+ ret = self.env._get_observations()
+
+ return ret
+
+ def step(self, action):
+ """
+ Extends vanilla step() function call to accommodate domain randomization
+
+ Returns:
+ 4-tuple:
+
+ - (OrderedDict) observations from the environment
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) misc information
+ """
+ # Step the internal randomization state
+ self.step_randomization()
+
+ return super().step(action)
+
+ def step_randomization(self):
+ """
+ Steps the internal randomization state
+ """
+ # functionality for randomizing at a particular frequency
+ if self.randomize_every_n_steps > 0:
+ if self.step_counter % self.randomize_every_n_steps == 0:
+ self.randomize_domain()
+ self.step_counter += 1
+
+ def randomize_domain(self):
+ """
+ Runs domain randomization over the environment.
+ """
+ for modder in self.modders:
+ modder.randomize()
+
+ def save_default_domain(self):
+ """
+ Saves the current simulation model parameters so
+ that they can be restored later.
+ """
+ for modder in self.modders:
+ modder.save_defaults()
+
+ def restore_default_domain(self):
+ """
+ Restores the simulation model parameters saved
+ in the last call to @save_default_domain.
+ """
+ for modder in self.modders:
+ modder.restore_defaults()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/gym_wrapper.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/gym_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..329cdaf2ad2b6b691a79e4e0386f8ce24bee7cc4
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/gym_wrapper.py
@@ -0,0 +1,134 @@
+"""
+This file implements a wrapper for facilitating compatibility with OpenAI gym.
+This is useful when using these environments with code that assumes a gym-like
+interface.
+"""
+
+import numpy as np
+import gymnasium as gym
+from gymnasium import spaces, Env
+
+from robosuite.wrappers import Wrapper
+
+
+class GymWrapper(Wrapper, gym.Env):
+ metadata = None
+ render_mode = None
+ """
+ Initializes the Gym wrapper. Mimics many of the required functionalities of the Wrapper class
+ found in the gym.core module
+
+ Args:
+ env (MujocoEnv): The environment to wrap.
+ keys (None or list of str): If provided, each observation will
+ consist of concatenated keys from the wrapped environment's
+ observation dictionary. Defaults to proprio-state and object-state.
+
+ Raises:
+ AssertionError: [Object observations must be enabled if no keys]
+ """
+
+ def __init__(self, env, keys=None):
+ # Run super method
+ super().__init__(env=env)
+ # Create name for gym
+ robots = "".join([type(robot.robot_model).__name__ for robot in self.env.robots])
+ self.name = robots + "_" + type(self.env).__name__
+
+ # Get reward range
+ self.reward_range = (0, self.env.reward_scale)
+
+ if keys is None:
+ keys = []
+ # Add object obs if requested
+ if self.env.use_object_obs:
+ keys += ["object-state"]
+ # Add image obs if requested
+ if self.env.use_camera_obs:
+ keys += [f"{cam_name}_image" for cam_name in self.env.camera_names]
+ # Iterate over all robots to add to state
+ for idx in range(len(self.env.robots)):
+ keys += ["robot{}_proprio-state".format(idx)]
+ self.keys = keys
+
+ # Gym specific attributes
+ self.env.spec = None
+
+ # set up observation and action spaces
+ obs = self.env.reset()
+ self.modality_dims = {key: obs[key].shape for key in self.keys}
+ flat_ob = self._flatten_obs(obs)
+ self.obs_dim = flat_ob.size
+ high = np.inf * np.ones(self.obs_dim)
+ low = -high
+ self.observation_space = spaces.Box(low, high)
+ low, high = self.env.action_spec
+ self.action_space = spaces.Box(low, high)
+
+ def _flatten_obs(self, obs_dict, verbose=False):
+ """
+ Filters keys of interest out and concatenate the information.
+
+ Args:
+ obs_dict (OrderedDict): ordered dictionary of observations
+ verbose (bool): Whether to print out to console as observation keys are processed
+
+ Returns:
+ np.array: observations flattened into a 1d array
+ """
+ ob_lst = []
+ for key in self.keys:
+ if key in obs_dict:
+ if verbose:
+ print("adding key: {}".format(key))
+ ob_lst.append(np.array(obs_dict[key]).flatten())
+ return np.concatenate(ob_lst)
+
+ def reset(self, seed=None, options=None):
+ """
+ Extends env reset method to return flattened observation instead of normal OrderedDict and optionally resets seed
+
+ Returns:
+ np.array: Flattened environment observation space after reset occurs
+ """
+ if seed is not None:
+ if isinstance(seed, int):
+ np.random.seed(seed)
+ else:
+ raise TypeError("Seed must be an integer type!")
+ ob_dict = self.env.reset()
+ return self._flatten_obs(ob_dict), {}
+
+ def step(self, action):
+ """
+ Extends vanilla step() function call to return flattened observation instead of normal OrderedDict.
+
+ Args:
+ action (np.array): Action to take in environment
+
+ Returns:
+ 4-tuple:
+
+ - (np.array) flattened observations from the environment
+ - (float) reward from the environment
+ - (bool) episode ending after reaching an env terminal state
+ - (bool) episode ending after an externally defined condition
+ - (dict) misc information
+ """
+ ob_dict, reward, terminated, info = self.env.step(action)
+ return self._flatten_obs(ob_dict), reward, terminated, False, info
+
+ def compute_reward(self, achieved_goal, desired_goal, info):
+ """
+ Dummy function to be compatible with gym interface that simply returns environment reward
+
+ Args:
+ achieved_goal: [NOT USED]
+ desired_goal: [NOT USED]
+ info: [NOT USED]
+
+ Returns:
+ float: environment reward
+ """
+ # Dummy args used to mimic Wrapper interface
+ return self.env.reward()
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/visualization_wrapper.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/visualization_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b847af6ee67acbe61353340bfbf06f2cb09c1e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/visualization_wrapper.py
@@ -0,0 +1,186 @@
+"""
+This file implements a wrapper for visualizing important sites in a given environment.
+
+By default, this visualizes all sites possible for the environment. Visualization options
+for a given environment can be found by calling `get_visualization_settings()`, and can
+be set individually by calling `set_visualization_setting(setting, visible)`.
+"""
+import xml.etree.ElementTree as ET
+from copy import deepcopy
+
+import numpy as np
+
+from robosuite.utils.mjcf_utils import new_body, new_geom, new_site
+from robosuite.wrappers import Wrapper
+
+DEFAULT_INDICATOR_SITE_CONFIG = {
+ "type": "sphere",
+ "size": [0.03],
+ "rgba": [1, 0, 0, 0.5],
+}
+
+
+class VisualizationWrapper(Wrapper):
+ def __init__(self, env, indicator_configs=None):
+ """
+ Initializes the data collection wrapper. Note that this automatically conducts a (hard) reset initially to make
+ sure indicators are properly added to the sim model.
+
+ Args:
+ env (MujocoEnv): The environment to visualize
+
+ indicator_configs (None or str or dict or list): Configurations to use for indicator objects.
+
+ If None, no indicator objects will be used
+
+ If a string, this should be `'default'`, which corresponds to single default spherical indicator
+
+ If a dict, should specify a single indicator object config
+
+ If a list, should specify specific indicator object configs to use for multiple indicators (which in
+ turn can either be `'default'` or a dict)
+
+ As each indicator object is essentially a site element, each dict should map site attribute keywords to
+ values. Note that, at the very minimum, the `'name'` attribute MUST be specified for each indicator. See
+ http://www.mujoco.org/book/XMLreference.html#site for specific site attributes that can be specified.
+ """
+ super().__init__(env)
+
+ # Make sure that the environment is NOT using segmentation sensors, since we cannot use segmentation masks
+ # with visualization sites simultaneously
+ assert all(
+ seg is None for seg in env.camera_segmentations
+ ), "Cannot use camera segmentations with visualization wrapper!"
+
+ # Standardize indicator configs
+ self.indicator_configs = None
+ if indicator_configs is not None:
+ self.indicator_configs = []
+ if type(indicator_configs) in {str, dict}:
+ indicator_configs = [indicator_configs]
+ for i, indicator_config in enumerate(indicator_configs):
+ if indicator_config == "default":
+ indicator_config = deepcopy(DEFAULT_INDICATOR_SITE_CONFIG)
+ indicator_config["name"] = f"indicator{i}"
+ # Make sure name attribute is specified
+ assert "name" in indicator_config, "Name must be specified for all indicator object configurations!"
+ # Add this configuration to the internal array
+ self.indicator_configs.append(indicator_config)
+
+ # Create internal dict to store visualization settings (set to True by default)
+ self._vis_settings = {vis: True for vis in self.env._visualizations}
+
+ # Add the post-processor to make sure indicator objects get added to model before it's actually loaded in sim
+ self.env.set_xml_processor(processor=self._add_indicators_to_model)
+
+ # Conduct a (hard) reset to make sure visualization changes propagate
+ reset_mode = self.env.hard_reset
+ self.env.hard_reset = True
+ self.reset()
+ self.env.hard_reset = reset_mode
+
+ def get_indicator_names(self):
+ """
+ Gets all indicator object names for this environment.
+
+ Returns:
+ list: Indicator names for this environment.
+ """
+ return (
+ [ind_config["name"] for ind_config in self.indicator_configs] if self.indicator_configs is not None else []
+ )
+
+ def set_indicator_pos(self, indicator, pos):
+ """
+ Sets the specified @indicator to the desired position @pos
+
+ Args:
+ indicator (str): Name of the indicator to set
+ pos (3-array): (x, y, z) Cartesian world coordinates to set the specified indicator to
+ """
+ # Make sure indicator is valid
+ indicator_names = set(self.get_indicator_names())
+ assert indicator in indicator_names, "Invalid indicator name specified. Valid options are {}, got {}".format(
+ indicator_names, indicator
+ )
+ # Set the specified indicator
+ self.env.sim.model.body_pos[self.env.sim.model.body_name2id(indicator + "_body")] = np.array(pos)
+
+ def get_visualization_settings(self):
+ """
+ Gets all settings for visualizing this environment
+
+ Returns:
+ list: Visualization keywords for this environment.
+ """
+ return self._vis_settings.keys()
+
+ def set_visualization_setting(self, setting, visible):
+ """
+ Sets the specified @setting to have visibility = @visible.
+
+ Args:
+ setting (str): Visualization keyword to set
+ visible (bool): True if setting should be visualized.
+ """
+ assert (
+ setting in self._vis_settings
+ ), "Invalid visualization setting specified. Valid options are {}, got {}".format(
+ self._vis_settings.keys(), setting
+ )
+ self._vis_settings[setting] = visible
+
+ def reset(self):
+ """
+ Extends vanilla reset() function call to accommodate visualization
+
+ Returns:
+ OrderedDict: Environment observation space after reset occurs
+ """
+ ret = super().reset()
+ # Update any visualization
+ self.env.visualize(vis_settings=self._vis_settings)
+ return ret
+
+ def step(self, action):
+ """
+ Extends vanilla step() function call to accommodate visualization
+
+ Args:
+ action (np.array): Action to take in environment
+
+ Returns:
+ 4-tuple:
+
+ - (OrderedDict) observations from the environment
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) misc information
+ """
+ ret = super().step(action)
+
+ # Update any visualization
+ self.env.visualize(vis_settings=self._vis_settings)
+
+ return ret
+
+ def _add_indicators_to_model(self, xml):
+ """
+ Adds indicators to the mujoco simulation model
+
+ Args:
+ xml (string): MJCF model in xml format, for the current simulation to be loaded
+ """
+ if self.indicator_configs is not None:
+ root = ET.fromstring(xml)
+ worldbody = root.find("worldbody")
+
+ for indicator_config in self.indicator_configs:
+ config = deepcopy(indicator_config)
+ indicator_body = new_body(name=config["name"] + "_body", pos=config.pop("pos", (0, 0, 0)))
+ indicator_body.append(new_site(**config))
+ worldbody.append(indicator_body)
+
+ xml = ET.tostring(root, encoding="utf8").decode("utf8")
+
+ return xml
diff --git a/phantom/submodules/phantom-robosuite/robosuite/wrappers/wrapper.py b/phantom/submodules/phantom-robosuite/robosuite/wrappers/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c922dfdfc0fcd58109e4a36c7815dea4524cb8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/robosuite/wrappers/wrapper.py
@@ -0,0 +1,135 @@
+"""
+This file contains the base wrapper class for Mujoco environments.
+Wrappers are useful for data collection and logging. Highly recommended.
+"""
+
+
+class Wrapper:
+ """
+ Base class for all wrappers in robosuite.
+
+ Args:
+ env (MujocoEnv): The environment to wrap.
+ """
+
+ def __init__(self, env):
+ self.env = env
+
+ @classmethod
+ def class_name(cls):
+ return cls.__name__
+
+ def _warn_double_wrap(self):
+ """
+ Utility function that checks if we're accidentally trying to double wrap an env
+
+ Raises:
+ Exception: [Double wrapping env]
+ """
+ env = self.env
+ while True:
+ if isinstance(env, Wrapper):
+ if env.class_name() == self.class_name():
+ raise Exception("Attempted to double wrap with Wrapper: {}".format(self.__class__.__name__))
+ env = env.env
+ else:
+ break
+
+ def step(self, action):
+ """
+ By default, run the normal environment step() function
+
+ Args:
+ action (np.array): action to take in environment
+
+ Returns:
+ 4-tuple:
+
+ - (OrderedDict) observations from the environment
+ - (float) reward from the environment
+ - (bool) whether the current episode is completed or not
+ - (dict) misc information
+ """
+ return self.env.step(action)
+
+ def reset(self):
+ """
+ By default, run the normal environment reset() function
+
+ Returns:
+ OrderedDict: Environment observation space after reset occurs
+ """
+ return self.env.reset()
+
+ def render(self, **kwargs):
+ """
+ By default, run the normal environment render() function
+
+ Args:
+ **kwargs (dict): Any args to pass to environment render function
+ """
+ return self.env.render(**kwargs)
+
+ def observation_spec(self):
+ """
+ By default, grabs the normal environment observation_spec
+
+ Returns:
+ OrderedDict: Observations from the environment
+ """
+ return self.env.observation_spec()
+
+ @property
+ def action_spec(self):
+ """
+ By default, grabs the normal environment action_spec
+
+ Returns:
+ 2-tuple:
+
+ - (np.array) minimum (low) action values
+ - (np.array) maximum (high) action values
+ """
+ return self.env.action_spec
+
+ @property
+ def action_dim(self):
+ """
+ By default, grabs the normal environment action_dim
+
+ Returns:
+ int: Action space dimension
+ """
+ return self.env.dof
+
+ @property
+ def unwrapped(self):
+ """
+ Grabs unwrapped environment
+
+ Returns:
+ env (MujocoEnv): Unwrapped environment
+ """
+ if hasattr(self.env, "unwrapped"):
+ return self.env.unwrapped
+ else:
+ return self.env
+
+ # this method is a fallback option on any methods the original env might support
+ def __getattr__(self, attr):
+ # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called
+ # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute)
+ orig_attr = getattr(self.env, attr)
+ if callable(orig_attr):
+
+ def hooked(*args, **kwargs):
+ result = orig_attr(*args, **kwargs)
+ # prevent wrapped_class from becoming unwrapped
+ # NOTE: had to use "is" to prevent errors when returning numpy arrays from a wrapped method
+ if result is self.env:
+ return self
+ return result
+
+ return hooked
+ else:
+ return orig_attr
diff --git a/phantom/submodules/phantom-robosuite/setup.py b/phantom/submodules/phantom-robosuite/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b74e6955f406b3023459136090d923bd9949ab
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/setup.py
@@ -0,0 +1,37 @@
+# read the contents of your README file
+from os import path
+
+from setuptools import find_packages, setup
+
+this_directory = path.abspath(path.dirname(__file__))
+with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
+ lines = f.readlines()
+
+# remove images from README
+lines = [x for x in lines if ".png" not in x]
+long_description = "".join(lines)
+
+setup(
+ name="robosuite",
+ packages=[package for package in find_packages() if package.startswith("robosuite")],
+ install_requires=[
+ "numpy>=1.13.3",
+ "numba>=0.49.1",
+ "scipy>=1.2.3",
+ "mujoco>=2.3.0",
+ "Pillow",
+ "opencv-python",
+ "pynput",
+ "termcolor",
+ ],
+ eager_resources=["*"],
+ include_package_data=True,
+ python_requires=">=3",
+ description="robosuite: A Modular Simulation Framework and Benchmark for Robot Learning",
+ author="Yuke Zhu",
+ url="https://github.com/ARISE-Initiative/robosuite",
+ author_email="yukez@cs.utexas.edu",
+ version="1.4.1",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_controllers/test_all_controllers.py b/phantom/submodules/phantom-robosuite/tests/test_controllers/test_all_controllers.py
new file mode 100644
index 0000000000000000000000000000000000000000..356057b0d7aa4f748067fbfe9c8ac801c6d8003e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_controllers/test_all_controllers.py
@@ -0,0 +1,155 @@
+"""
+Test all controllers on the Lift task with Sawyer robot environment as a test case.
+
+The following controllers are tested:
+Operational Space Control - Position & Orientation
+Operational Space Control - Position only
+Inverse Kinematics - Position & Orientation
+Joint Impedance
+Joint Velocity
+Joint Torque
+
+This (non-exhaustive) test script checks for qualitative irregularities in controller behavior.
+However, this testing module also checks for action space correctness and dimensionality.
+For every controller action space, runs through each dimension and executes a perturbation "test_value" from its
+neutral (stationary) value for a certain amount of time "steps_per_action", and then returns to all neutral values
+for time "steps_per_rest" before proceeding with the next action dim.
+
+ E.g.: Given that the expected action space of the Pos / Ori (OSC_POSE) controller (without a gripper) is
+ (dx, dy, dz, ax, ay, az), the testing sequence of actions over time will be:
+
+ ***START OF TEST***
+ ( dx, 0, 0, 0, 0, 0, grip) <-- Translation in x-direction for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, dy, 0, 0, 0, 0, grip) <-- Translation in y-direction for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, dz, 0, 0, 0, grip) <-- Translation in z-direction for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, 0, a, 0, 0, grip) <-- Rotation about x axis for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, 0, 0, a, 0, grip) <-- Rotation about y axis for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ( 0, 0, 0, 0, 0, a, grip) <-- Rotation about z axis for 'steps_per_action' steps
+ ( 0, 0, 0, 0, 0, 0, grip) <-- No movement (pause) for 'steps_per_rest' steps
+ ***END OF TEST***
+
+ Thus the OSC_POSE controller should be expected to sequentially move linearly in the x direction first,
+ then the y direction, then the z direction, and then begin sequentially rotating about its x-axis,
+ then y-axis, then z-axis.
+
+Please reference the controller README in the robosuite/controllers directory for an overview of each controller.
+Controllers are expected to behave in a generally controlled manner, according to their control space.
+ E.g.: the Pos / Ori controller should be expected to move linearly in the x direction first, then the y direction,
+ then the z direction, and then begin rotating about its x-axis, then y-axis, then z-axis.
+
+As this is strictly a qualitative set of tests, it is up to the developer / user to examine for specific irregularities.
+However, the expected qualitative behavior is described below for each controller:
+
+* OSC_POSE: Gripper moves sequentially and linearly in x, y, z direction, then sequentially rotates in x-axis,
+ y-axis, z-axis, relative to the global coordinate frame
+* OSC_POSITION: Gripper moves sequentially and linearly in x, y, z direction, relative to the global coordinate frame
+* IK_POSE: Gripper moves sequentially and linearly in x, y, z direction, then sequentially rotates in x-axis, y-axis,
+ z-axis, relative to the local robot end effector frame
+* JOINT_POSITION: Robot Joints move sequentially in a controlled fashion
+* JOINT_VELOCITY: Robot Joints move sequentially in a controlled fashion
+* JOINT_TORQUE: Unlike other controllers, joint torque controller is expected to act rather lethargic, as the
+ "controller" is really just a wrapper for direct torque control of the mujoco actuators. Therefore, a
+ "neutral" value of 0 torque will not guarantee a stable robot when it has non-zero velocity!
+
+Note that by default, there is no rendering. Rendering can be enabled by setting the --render flag when calling this
+test script.
+
+"""
+import argparse
+
+import numpy as np
+
+import robosuite as suite
+import robosuite.utils.transform_utils as T
+from robosuite import load_controller_config
+
+# Arguments for this test script
+parser = argparse.ArgumentParser()
+parser.add_argument("--render", action="store_true", help="Whether to render this test or not for visual validation")
+args = parser.parse_args()
+
+# Define the controllers to use (action_dim, num_test_steps, test_value)
+controllers = {
+ "OSC_POSE": [7, 6, 0.1],
+ "OSC_POSITION": [4, 3, 0.1],
+ "IK_POSE": [7, 6, 0.01],
+ "JOINT_POSITION": [8, 7, 0.2],
+ "JOINT_VELOCITY": [8, 7, -0.1],
+ "JOINT_TORQUE": [8, 7, 0.25],
+}
+
+# Define the number of timesteps to use per controller action as well as timesteps in between actions
+steps_per_action = 50
+steps_per_rest = 25
+
+
+def test_all_controllers():
+ for controller_name in controllers.keys():
+ # Define variables for each controller test
+ action_dim = controllers[controller_name][0]
+ num_test_steps = controllers[controller_name][1]
+ test_value = controllers[controller_name][2]
+ neutral = np.zeros(action_dim)
+
+ # Define controller path to load
+ controller_config = load_controller_config(default_controller=controller_name)
+
+ # Now, create a test env for testing the controller on
+ env = suite.make(
+ "Lift",
+ robots="Sawyer",
+ has_renderer=args.render, # use on-screen renderer for visual validation only if requested
+ has_offscreen_renderer=False,
+ use_camera_obs=False,
+ horizon=(steps_per_action + steps_per_rest) * num_test_steps,
+ controller_configs=controller_config,
+ )
+ print("Testing controller: {}...".format(controller_name))
+
+ env.reset()
+ # If rendering, set controller to front view to get best angle for viewing robot movements
+ if args.render:
+ env.viewer.set_camera(camera_id=0)
+
+ # get action range
+ action_min, action_max = env.action_spec
+ assert action_min.shape == action_max.shape
+ assert action_min.shape[0] == action_dim, "Expected {}, got {}".format(action_dim, action_min.shape[0])
+
+ # Keep track of done variable to know when to break loop
+ count = 0
+ # Loop through controller space
+ while count < num_test_steps:
+ action = neutral.copy()
+ for i in range(steps_per_action):
+ if controller_name in {"IK_POSE", "OSC_POSE"} and count > 2:
+ # Set this value to be the angle and set appropriate axis
+ vec = np.zeros(3)
+ vec[count - 3] = test_value
+ action[3:6] = vec
+ else:
+ action[count] = test_value
+ env.step(action)
+ if args.render:
+ env.render()
+ for i in range(steps_per_rest):
+ env.step(neutral)
+ if args.render:
+ env.render()
+ count += 1
+
+ # Shut down this env before starting the next test
+ env.close()
+
+ # Tests passed!
+ print("All controller tests completed.")
+
+
+if __name__ == "__main__":
+
+ test_all_controllers()
diff --git a/phantom/submodules/phantom-robosuite/tests/test_controllers/test_linear_interpolator.py b/phantom/submodules/phantom-robosuite/tests/test_controllers/test_linear_interpolator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12729562c2bb9538dd3cc9b53b39b250a5622a8
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_controllers/test_linear_interpolator.py
@@ -0,0 +1,195 @@
+"""
+Test the linear interpolator on the Lift task with Sawyer arm environment as a test case.
+
+The linear interpolator is meant to increase the stability and overall safety of a robot arm's trajectory when reaching
+a setpoint, "ramping up" the actual action command sent to a given controller from zero to the actual inputted action
+over a fraction of the timesteps in betwteen each high-level input action (the "ramp ratio"). As a result, the
+resulting trajectory should be smoother, proportional to the interpolator's ramp ratio setting.
+
+This test verifies that the linear interpolator works correctly on both the IK and OSC controller for both position and
+orientation, and proceeds as follows:
+
+ 1. Given a constant delta position action, and with the interpolator disabled, we will measure the sum of absolute
+ changes in joint torques between individual simulation timesteps
+
+ 2. We will repeat Step 1, but this time with the interpolator enabled and with a ramp ratio of 1.0 (max value)
+
+ 3. We expect the interpolated trajectories to experience a smaller overall magnitude of changes in torques, due to
+ the setpoints between controller timesteps being smoothed out over the ramp ratio.
+
+Note: As this is a qualitative test, it is up to the user to evaluate the output and determine the expected behavior of
+the tested controllers.
+"""
+
+import argparse
+import json
+import os
+
+import numpy as np
+
+import robosuite as suite
+import robosuite.utils.transform_utils as T
+
+# Define the threshold locations, delta values, and ratio #
+
+# Translation trajectory
+pos_y_threshold = 0.1
+delta_pos_y = 0.01
+pos_action_osc = [0, delta_pos_y * 40, 0]
+pos_action_ik = [0, delta_pos_y, 0]
+
+# Rotation trajectory
+rot_r_threshold = np.pi / 2
+delta_rot_r = 0.01
+rot_action_osc = [delta_rot_r * 40, 0, 0]
+rot_action_ik = [delta_rot_r * 5, 0, 0]
+
+# Concatenated thresholds and corresponding indexes (y = 1 in x,y,z; roll = 0 in r,p,y)
+thresholds = [pos_y_threshold, rot_r_threshold]
+indexes = [1, 0]
+
+# Threshold ratio
+min_ratio = 1.10
+
+# Define arguments for this test
+parser = argparse.ArgumentParser()
+parser.add_argument("--render", action="store_true", help="Whether to render tests or run headless")
+args = parser.parse_args()
+
+# Setup printing options for numbers
+np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
+
+
+# function to run the actual sim in order to receive summed absolute delta torques
+def step(env, action, current_torques):
+ env.timestep += 1
+ policy_step = True
+ summed_abs_delta_torques = np.zeros(7)
+
+ for i in range(int(env.control_timestep / env.model_timestep)):
+ env.sim.forward()
+ env._pre_action(action, policy_step)
+ last_torques = current_torques
+ current_torques = env.robots[0].torques
+ summed_abs_delta_torques += np.abs(current_torques - last_torques)
+ env.sim.step()
+ policy_step = False
+
+ env.cur_time += env.control_timestep
+ out = env._post_action(action)
+ return out, summed_abs_delta_torques, current_torques
+
+
+# Running the actual test #
+def test_linear_interpolator():
+
+ for controller_name in ["IK_POSE", "OSC_POSE"]:
+
+ for traj in ["pos", "ori"]:
+
+ # Define counter to increment timesteps and torques for each trajectory
+ timesteps = [0, 0]
+ summed_abs_delta_torques = [np.zeros(7), np.zeros(7)]
+
+ for interpolator in [None, "linear"]:
+ # Define numpy seed so we guarantee consistent starting pos / ori for each trajectory
+ np.random.seed(3)
+
+ # Define controller path to load
+ controller_path = os.path.join(
+ os.path.dirname(__file__),
+ "../../robosuite",
+ "controllers/config/{}.json".format(controller_name.lower()),
+ )
+ with open(controller_path) as f:
+ controller_config = json.load(f)
+ controller_config["interpolation"] = interpolator
+ controller_config["ramp_ratio"] = 1.0
+
+ # Now, create a test env for testing the controller on
+ env = suite.make(
+ "Lift",
+ robots="Sawyer",
+ has_renderer=args.render, # by default, don't use on-screen renderer for visual validation
+ has_offscreen_renderer=False,
+ use_camera_obs=False,
+ horizon=10000,
+ control_freq=20,
+ controller_configs=controller_config,
+ )
+
+ # Reset the environment
+ env.reset()
+
+ # Hardcode the starting position for sawyer
+ init_qpos = [-0.5538, -0.8208, 0.4155, 1.8409, -0.4955, 0.6482, 1.9628]
+ env.robots[0].set_robot_joint_positions(init_qpos)
+ env.robots[0].controller.update_initial_joints(init_qpos)
+ env.robots[0].controller.reset_goal()
+
+ # Notify user a new trajectory is beginning
+ print(
+ "\nTesting controller {} with trajectory {} and interpolator={}...".format(
+ controller_name, traj, interpolator
+ )
+ )
+
+ # If rendering, set controller to front view to get best angle for viewing robot movements
+ if args.render:
+ env.viewer.set_camera(camera_id=0)
+
+ # Keep track of state of robot eef (pos, ori (euler)) and torques
+ current_torques = np.zeros(7)
+ initial_state = [env.robots[0]._hand_pos, T.mat2quat(env.robots[0]._hand_orn)]
+ dstate = [
+ env.robots[0]._hand_pos - initial_state[0],
+ T.mat2euler(T.quat2mat(T.quat_distance(T.mat2quat(env.robots[0]._hand_orn), initial_state[1]))),
+ ]
+
+ # Define the uniform trajectory action
+ if traj == "pos":
+ pos_act = pos_action_ik if controller_name == "IK_POSE" else pos_action_osc
+ rot_act = np.zeros(3)
+ else:
+ pos_act = np.zeros(3)
+ rot_act = rot_action_ik if controller_name == "IK_POSE" else rot_action_osc
+
+ # Compose the action
+ action = np.concatenate([pos_act, rot_act, [0]])
+
+ # Determine which trajectory we're executing
+ k = 0 if traj == "pos" else 1
+ j = 0 if not interpolator else 1
+
+ # Run trajectory until the threshold condition is met
+ while abs(dstate[k][indexes[k]]) < abs(thresholds[k]):
+ _, summed_torques, current_torques = step(env, action, current_torques)
+ if args.render:
+ env.render()
+
+ # Update torques, timestep count, and state
+ summed_abs_delta_torques[j] += summed_torques
+ timesteps[j] += 1
+ dstate = [
+ env.robots[0]._hand_pos - initial_state[0],
+ T.mat2euler(T.quat2mat(T.quat_distance(T.mat2quat(env.robots[0]._hand_orn), initial_state[1]))),
+ ]
+
+ # When finished, print out the timestep results
+ print(
+ "Completed trajectory. Avg per-step absolute delta torques: {}".format(
+ summed_abs_delta_torques[j] / timesteps[j]
+ )
+ )
+
+ # Shut down this env before starting the next test
+ env.close()
+
+ # Tests completed!
+ print()
+ print("-" * 80)
+ print("All linear interpolator testing completed.\n")
+
+
+if __name__ == "__main__":
+ test_linear_interpolator()
diff --git a/phantom/submodules/phantom-robosuite/tests/test_controllers/test_variable_impedance.py b/phantom/submodules/phantom-robosuite/tests/test_controllers/test_variable_impedance.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dc80c41de0b22526786987c69501bee1ed38596
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_controllers/test_variable_impedance.py
@@ -0,0 +1,189 @@
+"""
+Test the variable impedance feature of impedance-based controllers (OSC, Joint Position) on the Lift task with
+Sawyer arm environment as a test case.
+
+The variable impedance feature allows per-action fine-grained control over the specific impedance gains when executing
+impedance control (namely, "kp" and "damping" ratios). This allows a given controller to execute more complex and
+potentially interactive trajectories by varying the net impedance of the controlled actuators over time.
+
+This (qualitative) test verifies that the variable impedance works correctly on both the OSC Pose / Position and
+Joint Position controllers, and proceeds as follows:
+
+ 1. Given a constant delta position action, and with the the kp values set to critically-damped, we will ramp up
+ the kp values to its max and then ramp down the values. We qualitatively expect the arm to accelerate as the kp
+ values are ramped, and then slow down as they are decreased.
+
+ 2. The environment will then be reset. Given a constant delta position action, and with kp values set to its
+ default value, we will ramp up the damping values to its max and then ramp down the values. We qualitatively
+ expect the arm to slow down as the damping values are ramped, and then increase in speed as they are decreased.
+
+ 3. We will repeat Step 1 and 2 for each of the tested controllers.
+
+Periodic prijntouts should verify the above patterns; conversely, running the script with the "--render" argument will
+render the trajectories to allow for visual analysis of gains
+"""
+
+import argparse
+import json
+import os
+
+import numpy as np
+
+import robosuite as suite
+
+# Define the rate of change when sweeping through kp / damping values
+num_timesteps_per_change = 10
+percent_increase = 0.05
+
+# Define delta values for trajectory
+d = 0.05
+
+# Define default values for fixing one of the two gains
+kp_default = 150
+damping_default = 1 # critically damped
+
+# Define arguments for this test
+parser = argparse.ArgumentParser()
+parser.add_argument("--render", action="store_true", help="Whether to render tests or run headless")
+args = parser.parse_args()
+
+
+# Running the actual test #
+def test_variable_impedance():
+
+ for controller_name in ["OSC_POSE", "OSC_POSITION", "JOINT_POSITION"]:
+
+ # Define numpy seed so we guarantee consistent starting pos / ori for each trajectory
+ np.random.seed(3)
+
+ # Define controller path to load
+ controller_path = os.path.join(
+ os.path.dirname(__file__), "../../robosuite", "controllers/config/{}.json".format(controller_name.lower())
+ )
+
+ # Load the controller
+ with open(controller_path) as f:
+ controller_config = json.load(f)
+
+ # Manually edit impedance settings
+ controller_config["impedance_mode"] = "variable"
+ controller_config["kp_limits"] = [0, 300]
+ controller_config["damping_limits"] = [0, 10]
+
+ # Now, create a test env for testing the controller on
+ env = suite.make(
+ "Lift",
+ robots="Sawyer",
+ has_renderer=args.render, # by default, don't use on-screen renderer for visual validation
+ has_offscreen_renderer=False,
+ use_camera_obs=False,
+ horizon=10000,
+ control_freq=20,
+ controller_configs=controller_config,
+ )
+
+ # Setup printing options for numbers
+ np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
+
+ # Get limits on kp and damping values
+ # Define control dim. Note that this is not the action space, but internal dimensionality of gains
+ control_dim = 6 if "OSC" in controller_name else 7
+ low, high = env.action_spec
+ damping_low, kp_low = low[:control_dim], low[control_dim : 2 * control_dim]
+ damping_high, kp_high = high[:control_dim], high[control_dim : 2 * control_dim]
+ damping_range = damping_high - damping_low
+ kp_range = kp_high - kp_low
+
+ # Get delta values for trajectory
+ if controller_name == "OSC_POSE":
+ delta = np.array([0, d, 0, 0, 0, 0])
+ elif controller_name == "OSC_POSITION":
+ delta = np.array([0, d, 0])
+ else: # JOINT_POSITION
+ delta = np.array([d, 0, 0, 0, 0, 0, 0])
+
+ # Get total number of steps each test should take (num steps ramping up + num steps ramping down)
+ total_steps = num_timesteps_per_change / percent_increase * 2
+
+ # Run a test for both kp and damping
+ gains = ["kp", "damping"]
+
+ for gain in gains:
+
+ # Reset the environment
+ env.reset()
+
+ # Hardcode the starting position for sawyer
+ init_qpos = [-0.5538, -0.8208, 0.4155, 1.8409, -0.4955, 0.6482, 1.9628]
+ env.robots[0].set_robot_joint_positions(init_qpos)
+ env.robots[0].controller.update_initial_joints(init_qpos)
+
+ # Notify user a new test is beginning
+ print("\nTesting controller {} while sweeping {}...".format(controller_name, gain))
+
+ # If rendering, set controller to front view to get best angle for viewing robot movements
+ if args.render:
+ env.viewer.set_camera(camera_id=0)
+
+ # Keep track of relative changes in robot eef position
+ last_pos = env.robots[0]._hand_pos
+
+ # Initialize gains
+ if gain == "kp":
+ kp = kp_low
+ damping = damping_default * np.ones(control_dim)
+ gain_val = kp # alias for kp
+ gain_range = kp_range
+ else: # "damping"
+ kp = kp_default * np.ones(control_dim)
+ damping = damping_low
+ gain_val = damping # alias for damping
+ gain_range = damping_range
+
+ # Initialize counters
+ i = 0
+ sign = 1.0 # Whether to increase or decrease gain
+
+ # Run trajectory until the threshold condition is met
+ while i < total_steps:
+ # Create action (damping, kp, traj, gripper)
+ action = np.concatenate([damping, kp, sign * delta, [0]])
+
+ # Take an environment step
+ env.step(action)
+ if args.render:
+ env.render()
+
+ # Update the current change in state
+ cur_pos = env.robots[0]._hand_pos
+
+ # If we're at the end of the increase, switch direction of traj and gain changes
+ if i == int(num_timesteps_per_change / percent_increase):
+ sign *= -1.0
+
+ # Update gain if this is a changing step
+ if i % num_timesteps_per_change == 0:
+ # Compare delta, print out to user, and update last_pos
+ delta_pos = np.linalg.norm(cur_pos - last_pos)
+ print(" Magnitude eef distance change with {} = {}: {:.5f}".format(gain, gain_val[0], delta_pos))
+ last_pos = cur_pos
+ # Update gain
+ gain_val += percent_increase * gain_range * sign
+
+ # Update timestep count
+ i += 1
+
+ # When finished, print out the timestep results
+ print("Completed trajectory.")
+
+ # Shut down this env before starting the next test
+ env.close()
+
+ # Tests completed!
+ print()
+ print("-" * 80)
+ print("All variable impedance testing completed.\n")
+
+
+if __name__ == "__main__":
+ test_variable_impedance()
diff --git a/phantom/submodules/phantom-robosuite/tests/test_environments/test_action_playback.py b/phantom/submodules/phantom-robosuite/tests/test_environments/test_action_playback.py
new file mode 100644
index 0000000000000000000000000000000000000000..96f256e0cd10fc22a8e923b2c8d8c4791680ed23
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_environments/test_action_playback.py
@@ -0,0 +1,76 @@
+"""
+Test script for recording a sequence of random actions and playing them back
+"""
+
+import argparse
+import json
+import os
+import random
+
+import h5py
+import numpy as np
+
+import robosuite
+from robosuite.controllers import load_controller_config
+
+
+def test_playback():
+ # set seeds
+ random.seed(0)
+ np.random.seed(0)
+
+ env = robosuite.make(
+ "Lift",
+ robots=["Panda"],
+ controller_configs=load_controller_config(default_controller="OSC_POSE"),
+ has_renderer=False,
+ has_offscreen_renderer=False,
+ ignore_done=True,
+ use_camera_obs=False,
+ reward_shaping=True,
+ control_freq=20,
+ )
+ env.reset()
+
+ # task instance
+ task_xml = env.sim.model.get_xml()
+ task_init_state = np.array(env.sim.get_state().flatten())
+
+ # trick for ensuring that we can play MuJoCo demonstrations back
+ # deterministically by using the recorded actions open loop
+ env.reset_from_xml_string(task_xml)
+ env.sim.reset()
+ env.sim.set_state_from_flattened(task_init_state)
+ env.sim.forward()
+
+ # random actions to play
+ n_actions = 100
+ actions = 0.1 * np.random.uniform(low=-1.0, high=1.0, size=(n_actions, env.action_spec[0].shape[0]))
+
+ # play actions
+ print("playing random actions...")
+ states = [task_init_state]
+ for i in range(n_actions):
+ env.step(actions[i])
+ states.append(np.array(env.sim.get_state().flatten()))
+
+ # try playback
+ print("attempting playback...")
+ env.reset()
+ env.reset_from_xml_string(task_xml)
+ env.sim.reset()
+ env.sim.set_state_from_flattened(task_init_state)
+ env.sim.forward()
+
+ for i in range(n_actions):
+ env.step(actions[i])
+ state_playback = env.sim.get_state().flatten()
+ assert np.all(np.equal(states[i + 1], state_playback))
+
+ env.close()
+ print("test passed!")
+
+
+if __name__ == "__main__":
+
+ test_playback()
diff --git a/phantom/submodules/phantom-robosuite/tests/test_environments/test_all_environments.py b/phantom/submodules/phantom-robosuite/tests/test_environments/test_all_environments.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbbe918653d248efe2909598f53c990829369fae
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_environments/test_all_environments.py
@@ -0,0 +1,88 @@
+"""
+Test all environments with random policies.
+
+This runs some basic sanity checks on the environment, namely, checking that:
+ - proprio-state exists in the obs, and is a flat array
+ - agentview_image exists and is of the correct shape
+ - no object-obs in state, because we are only using image observations
+
+Obviously, if an environment crashes during runtime, that is considered a failure as well.
+"""
+import numpy as np
+
+import robosuite as suite
+
+
+def test_all_environments():
+
+ envs = sorted(suite.ALL_ENVIRONMENTS)
+ for env_name in envs:
+ # Create config dict
+ env_config = {"env_name": env_name}
+ for robot_name in ("Panda", "Sawyer", "Baxter"):
+ # create an environment for learning on pixels
+ config = None
+ if "TwoArm" in env_name:
+ if robot_name == "Baxter":
+ robots = robot_name
+ config = "bimanual"
+ else:
+ robots = [robot_name, robot_name]
+ config = "single-arm-opposed"
+ # compile configuration specs
+ env_config["robots"] = robots
+ env_config["env_configuration"] = config
+ else:
+ if robot_name == "Baxter":
+ continue
+ env_config["robots"] = robot_name
+
+ # Notify user of which test we are currently on
+ print("Testing env: {} with robots {} with config {}...".format(env_name, env_config["robots"], config))
+
+ # Create environment
+ env = suite.make(
+ **env_config,
+ has_renderer=False, # no on-screen renderer
+ has_offscreen_renderer=True, # off-screen renderer is required for camera observations
+ ignore_done=True, # (optional) never terminates episode
+ use_camera_obs=True, # use camera observations
+ camera_heights=84, # set camera height
+ camera_widths=84, # set camera width
+ camera_names="agentview", # use "agentview" camera
+ use_object_obs=False, # no object feature when training on pixels
+ reward_shaping=True, # (optional) using a shaping reward
+ )
+
+ obs = env.reset()
+
+ # get action range
+ action_min, action_max = env.action_spec
+ assert action_min.shape == action_max.shape
+
+ # Get robot prefix
+ pr = env.robots[0].robot_model.naming_prefix
+
+ # run 10 random actions
+ for _ in range(10):
+
+ assert pr + "proprio-state" in obs
+ assert obs[pr + "proprio-state"].ndim == 1
+
+ assert "agentview_image" in obs
+ assert obs["agentview_image"].shape == (84, 84, 3)
+
+ assert "object-state" not in obs
+
+ action = np.random.uniform(action_min, action_max)
+ obs, reward, done, info = env.step(action)
+
+ env.close()
+
+ # Tests passed!
+ print("All environment tests passed successfully!")
+
+
+if __name__ == "__main__":
+
+ test_all_environments()
diff --git a/phantom/submodules/phantom-robosuite/tests/test_environments/test_camera_transforms.py b/phantom/submodules/phantom-robosuite/tests/test_environments/test_camera_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3e8b6516c2f51edc883fcf83b5da0ad1c73ac6
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_environments/test_camera_transforms.py
@@ -0,0 +1,93 @@
+"""
+Test script for camera transforms. This test will read the ground-truth
+object state in the Lift environment, transform it into a pixel location
+in the camera frame, then transform it back to the world frame, and assert
+that the values are close.
+"""
+import random
+
+import numpy as np
+
+import robosuite
+import robosuite.utils.camera_utils as CU
+from robosuite.controllers import load_controller_config
+
+
+def test_camera_transforms():
+ # set seeds
+ random.seed(0)
+ np.random.seed(0)
+
+ camera_name = "agentview"
+ camera_height = 120
+ camera_width = 120
+ env = robosuite.make(
+ "Lift",
+ robots=["Panda"],
+ controller_configs=load_controller_config(default_controller="OSC_POSE"),
+ has_renderer=False,
+ has_offscreen_renderer=True,
+ ignore_done=True,
+ use_object_obs=True,
+ use_camera_obs=True,
+ camera_names=[camera_name],
+ camera_depths=[True],
+ camera_heights=[camera_height],
+ camera_widths=[camera_width],
+ reward_shaping=True,
+ control_freq=20,
+ )
+ obs_dict = env.reset()
+ sim = env.sim
+
+ # ground-truth object position
+ obj_pos = obs_dict["object-state"][:3]
+
+ # camera frame
+ image = obs_dict["{}_image".format(camera_name)][::-1]
+
+ # unnormalized depth map
+ depth_map = obs_dict["{}_depth".format(camera_name)][::-1]
+
+ depth_map = CU.get_real_depth_map(sim=env.sim, depth_map=depth_map)
+
+ # get camera matrices
+ world_to_camera = CU.get_camera_transform_matrix(
+ sim=env.sim,
+ camera_name=camera_name,
+ camera_height=camera_height,
+ camera_width=camera_width,
+ )
+ camera_to_world = np.linalg.inv(world_to_camera)
+
+ # transform object position into camera pixel
+ obj_pixel = CU.project_points_from_world_to_camera(
+ points=obj_pos,
+ world_to_camera_transform=world_to_camera,
+ camera_height=camera_height,
+ camera_width=camera_width,
+ )
+
+ # transform from camera pixel back to world position
+ estimated_obj_pos = CU.transform_from_pixels_to_world(
+ pixels=obj_pixel,
+ depth_map=depth_map,
+ camera_to_world_transform=camera_to_world,
+ )
+
+ # the most we should be off by in the z-direction is 3^0.5 times the maximum half-size of the cube
+ max_z_err = np.sqrt(3) * 0.022
+ z_err = np.abs(obj_pos[2] - estimated_obj_pos[2])
+ assert z_err < max_z_err
+
+ print("pixel: {}".format(obj_pixel))
+ print("obj pos: {}".format(obj_pos))
+ print("estimated obj pos: {}".format(estimated_obj_pos))
+ print("z err: {}".format(z_err))
+
+ env.close()
+
+
+if __name__ == "__main__":
+
+ test_camera_transforms()
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_all_grippers.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_all_grippers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ebae9cb04220b1e0e9f70daafe073695f49f377
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_all_grippers.py
@@ -0,0 +1,29 @@
+"""
+Tests the basic interface of all grippers.
+
+This runs some basic sanity checks on the environment, namely, checking that:
+ - Verifies that the gripper's action, init_qpos exist and are valid
+
+Obviously, if an environment crashes during runtime, that is considered a failure as well.
+"""
+from robosuite.models.grippers import GRIPPER_MAPPING
+
+
+def test_all_gripper():
+ for name, gripper in GRIPPER_MAPPING.items():
+ # Test all grippers except the null gripper
+ if name not in {None, "WipingGripper"}:
+ print("Testing {}...".format(name))
+ _test_gripper(gripper())
+
+
+def _test_gripper(gripper):
+ action = gripper.format_action([1] * gripper.dof)
+ assert action is not None
+
+ assert gripper.init_qpos is not None
+
+
+if __name__ == "__main__":
+ test_all_gripper()
+ print("Gripper tests completed.")
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_jaco_threefinger.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_jaco_threefinger.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64f6ba8bb6910e852fa6505840bb2a1fcc2cc1e
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_jaco_threefinger.py
@@ -0,0 +1,26 @@
+from robosuite.models.grippers import GripperTester, JacoThreeFingerGripper
+
+
+def test_robotiq():
+ robotiq_tester(False)
+
+
+def robotiq_tester(render, total_iters=1, test_y=True):
+ gripper = JacoThreeFingerGripper()
+ tester = GripperTester(
+ gripper=gripper,
+ pos="0 0 0.3",
+ quat="0 0 1 0",
+ gripper_low_pos=0.01,
+ gripper_high_pos=0.1,
+ box_size=[0.025] * 3,
+ step_time=1000,
+ render=render,
+ )
+ tester.start_simulation()
+ tester.loop(total_iters=total_iters, test_y=test_y)
+ tester.close()
+
+
+if __name__ == "__main__":
+ robotiq_tester(True, 20, False)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_panda_gripper.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_panda_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3139ae2690639864c813547089149610968a0dd5
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_panda_gripper.py
@@ -0,0 +1,28 @@
+"""
+Tests panda gripper on grabbing task
+"""
+from robosuite.models.grippers import GripperTester, PandaGripper
+
+
+def test_panda_gripper():
+ panda_gripper_tester(False)
+
+
+def panda_gripper_tester(render, total_iters=1, test_y=True):
+ gripper = PandaGripper()
+ tester = GripperTester(
+ gripper=gripper,
+ pos="0 0 0.3",
+ quat="0 0 1 0",
+ gripper_low_pos=-0.10,
+ gripper_high_pos=0.01,
+ render=render,
+ )
+ tester.start_simulation()
+ tester.loop(total_iters=total_iters, test_y=test_y)
+ tester.close()
+
+
+if __name__ == "__main__":
+ panda_gripper_tester(True, 20, True)
+ panda_gripper_tester(True, 20, True)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_rethink_gripper.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_rethink_gripper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b937878bd9c05646bc3236e736f1bae083533f2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_rethink_gripper.py
@@ -0,0 +1,27 @@
+"""
+Tests two finger gripper and left two finger gripper on grabbing task
+"""
+from robosuite.models.grippers import GripperTester, RethinkGripper
+
+
+def test_two_finger():
+ two_finger_tester(False)
+
+
+def two_finger_tester(render, total_iters=1, test_y=True):
+ gripper = RethinkGripper()
+ tester = GripperTester(
+ gripper=gripper,
+ pos="0 0 0.3",
+ quat="0 0 1 0",
+ gripper_low_pos=-0.07,
+ gripper_high_pos=0.02,
+ render=render,
+ )
+ tester.start_simulation()
+ tester.loop(total_iters=total_iters, test_y=test_y)
+ tester.close()
+
+
+if __name__ == "__main__":
+ two_finger_tester(True, 20, True)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_140.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_140.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6235138beb779ed51df4aa9f2a4fd4a50aba029
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_140.py
@@ -0,0 +1,25 @@
+from robosuite.models.grippers import GripperTester, Robotiq140Gripper
+
+
+def test_robotiq():
+ robotiq_tester(False)
+
+
+def robotiq_tester(render, total_iters=1, test_y=True):
+ gripper = Robotiq140Gripper()
+ tester = GripperTester(
+ gripper=gripper,
+ pos="0 0 0.3",
+ quat="0 0 1 0",
+ gripper_low_pos=0.02,
+ gripper_high_pos=0.1,
+ box_size=[0.025] * 3,
+ render=render,
+ )
+ tester.start_simulation()
+ tester.loop(total_iters=total_iters, test_y=test_y)
+ tester.close()
+
+
+if __name__ == "__main__":
+ robotiq_tester(True, 20, False)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_85.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_85.py
new file mode 100644
index 0000000000000000000000000000000000000000..636b0b64ff8638fc971167bbcc05ca0b4e7a6a2f
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_85.py
@@ -0,0 +1,25 @@
+from robosuite.models.grippers import GripperTester, Robotiq85Gripper
+
+
+def test_robotiq():
+ robotiq_tester(False)
+
+
+def robotiq_tester(render, total_iters=1, test_y=True):
+ gripper = Robotiq85Gripper()
+ tester = GripperTester(
+ gripper=gripper,
+ pos="-0.02 0 0.3",
+ quat="0 0 1 0",
+ gripper_low_pos=-0.065,
+ gripper_high_pos=0.01,
+ box_size=[0.025] * 3,
+ render=render,
+ )
+ tester.start_simulation()
+ tester.loop(total_iters=total_iters, test_y=test_y)
+ tester.close()
+
+
+if __name__ == "__main__":
+ robotiq_tester(True, 20, False)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_threefinger.py b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_threefinger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a040d7686d4d46bd361474a0bc511c005a94b6a2
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_grippers/test_robotiq_threefinger.py
@@ -0,0 +1,26 @@
+from robosuite.models.grippers import GripperTester, RobotiqThreeFingerGripper
+
+
+def test_robotiq_three_finger():
+ robotiq_three_finger_tester(False)
+
+
+def robotiq_three_finger_tester(render, total_iters=1, test_y=True):
+ gripper = RobotiqThreeFingerGripper()
+ tester = GripperTester(
+ gripper=gripper,
+ pos="0 0 0.3",
+ quat="0 0 1 0",
+ gripper_low_pos=-0.02,
+ gripper_high_pos=0.1,
+ box_size=[0.035] * 3,
+ box_density=500,
+ render=render,
+ )
+ tester.start_simulation()
+ tester.loop(total_iters=total_iters, test_y=test_y)
+ tester.close()
+
+
+if __name__ == "__main__":
+ robotiq_three_finger_tester(True, 20, False)
diff --git a/phantom/submodules/phantom-robosuite/tests/test_robots/test_all_robots.py b/phantom/submodules/phantom-robosuite/tests/test_robots/test_all_robots.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6340dd2d9c96a5617d9720c9c11556433cae3b
--- /dev/null
+++ b/phantom/submodules/phantom-robosuite/tests/test_robots/test_all_robots.py
@@ -0,0 +1,28 @@
+"""
+Tests the basic interface of all robots.
+
+This runs some basic sanity checks on the robots, namely, checking that:
+ - Verifies that all single-arm robots have properly defined contact geoms.
+
+Obviously, if an environment crashes during runtime, that is considered a failure as well.
+"""
+from robosuite.robots import ROBOT_CLASS_MAPPING, SingleArm
+
+
+def test_single_arm_robots():
+ for name, robot in ROBOT_CLASS_MAPPING.items():
+ if robot == SingleArm:
+ print(f"Testing {name}")
+ _test_contact_geoms(robot(name))
+
+
+def _test_contact_geoms(robot):
+ robot.load_model()
+ contact_geoms = robot.robot_model._contact_geoms
+ for geom in contact_geoms:
+ assert isinstance(geom, str), f"The geom {geom} is of type {type(geom)}, but should be {type('placeholder')}"
+
+
+if __name__ == "__main__":
+ test_single_arm_robots()
+ print("Robot tests completed.")
diff --git a/phantom/submodules/sam2/.clang-format b/phantom/submodules/sam2/.clang-format
new file mode 100644
index 0000000000000000000000000000000000000000..39b1b3d603ed0cf6b7f94c9c08067f148f35613f
--- /dev/null
+++ b/phantom/submodules/sam2/.clang-format
@@ -0,0 +1,85 @@
+AccessModifierOffset: -1
+AlignAfterOpenBracket: AlwaysBreak
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+AlignEscapedNewlinesLeft: true
+AlignOperands: false
+AlignTrailingComments: false
+AllowAllParametersOfDeclarationOnNextLine: false
+AllowShortBlocksOnASingleLine: false
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: Empty
+AllowShortIfStatementsOnASingleLine: false
+AllowShortLoopsOnASingleLine: false
+AlwaysBreakAfterReturnType: None
+AlwaysBreakBeforeMultilineStrings: true
+AlwaysBreakTemplateDeclarations: true
+BinPackArguments: false
+BinPackParameters: false
+BraceWrapping:
+ AfterClass: false
+ AfterControlStatement: false
+ AfterEnum: false
+ AfterFunction: false
+ AfterNamespace: false
+ AfterObjCDeclaration: false
+ AfterStruct: false
+ AfterUnion: false
+ BeforeCatch: false
+ BeforeElse: false
+ IndentBraces: false
+BreakBeforeBinaryOperators: None
+BreakBeforeBraces: Attach
+BreakBeforeTernaryOperators: true
+BreakConstructorInitializersBeforeComma: false
+BreakAfterJavaFieldAnnotations: false
+BreakStringLiterals: false
+ColumnLimit: 80
+CommentPragmas: '^ IWYU pragma:'
+ConstructorInitializerAllOnOneLineOrOnePerLine: true
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+Cpp11BracedListStyle: true
+DerivePointerAlignment: false
+DisableFormat: false
+ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
+IncludeCategories:
+ - Regex: '^<.*\.h(pp)?>'
+ Priority: 1
+ - Regex: '^<.*'
+ Priority: 2
+ - Regex: '.*'
+ Priority: 3
+IndentCaseLabels: true
+IndentWidth: 2
+IndentWrappedFunctionNames: false
+KeepEmptyLinesAtTheStartOfBlocks: false
+MacroBlockBegin: ''
+MacroBlockEnd: ''
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: None
+ObjCBlockIndentWidth: 2
+ObjCSpaceAfterProperty: false
+ObjCSpaceBeforeProtocolList: false
+PenaltyBreakBeforeFirstCallParameter: 1
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 200
+PointerAlignment: Left
+ReflowComments: true
+SortIncludes: true
+SpaceAfterCStyleCast: false
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeParens: ControlStatements
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 1
+SpacesInAngles: false
+SpacesInContainerLiterals: true
+SpacesInCStyleCastParentheses: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+Standard: Cpp11
+TabWidth: 8
+UseTab: Never
diff --git a/phantom/submodules/sam2/.github/workflows/check_fmt.yml b/phantom/submodules/sam2/.github/workflows/check_fmt.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0a29b884af2b5c0bdb71b607e7b8220e879755be
--- /dev/null
+++ b/phantom/submodules/sam2/.github/workflows/check_fmt.yml
@@ -0,0 +1,17 @@
+name: SAM2/fmt
+on:
+ pull_request:
+ branches:
+ - main
+jobs:
+ ufmt_check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Check formatting
+ uses: omnilib/ufmt@action-v1
+ with:
+ path: sam2 tools
+ version: "2.0.0b2"
+ python-version: "3.10"
+ black-version: "24.2.0"
+ usort-version: "1.0.2"
diff --git a/phantom/submodules/sam2/.gitignore b/phantom/submodules/sam2/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..121d46aa5c1854ee2b2a5085ff2ce22f5a04043f
--- /dev/null
+++ b/phantom/submodules/sam2/.gitignore
@@ -0,0 +1,11 @@
+.vscode/
+.DS_Store
+__pycache__/
+*-checkpoint.ipynb
+.venv
+*.egg*
+build/*
+_C.*
+outputs/*
+checkpoints/*.pt
+demo/backend/checkpoints/*.pt
diff --git a/phantom/submodules/sam2/.watchmanconfig b/phantom/submodules/sam2/.watchmanconfig
new file mode 100644
index 0000000000000000000000000000000000000000..9e26dfeeb6e641a33dae4961196235bdb965b21b
--- /dev/null
+++ b/phantom/submodules/sam2/.watchmanconfig
@@ -0,0 +1 @@
+{}
\ No newline at end of file
diff --git a/phantom/submodules/sam2/CODE_OF_CONDUCT.md b/phantom/submodules/sam2/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..08b500a221857ec3f451338e80b4a9ab1173a1af
--- /dev/null
+++ b/phantom/submodules/sam2/CODE_OF_CONDUCT.md
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at . All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/phantom/submodules/sam2/CONTRIBUTING.md b/phantom/submodules/sam2/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad15049f583e1bc9a418686493405875b98c7f0f
--- /dev/null
+++ b/phantom/submodules/sam2/CONTRIBUTING.md
@@ -0,0 +1,31 @@
+# Contributing to segment-anything
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here:
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to segment-anything, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/INSTALL.md b/phantom/submodules/sam2/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..9480ba1bb52c171cfccc6a078c68abdb49125daa
--- /dev/null
+++ b/phantom/submodules/sam2/INSTALL.md
@@ -0,0 +1,189 @@
+## Installation
+
+### Requirements
+
+- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
+ * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
+- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
+- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
+
+Then, install SAM 2 from the root of this repository via
+```bash
+pip install -e ".[notebooks]"
+```
+
+Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
+```bash
+# skip the SAM 2 CUDA extension
+SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]"
+```
+This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
+
+### Building the SAM 2 CUDA extension
+
+By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)
+
+If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.
+
+If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows
+```bash
+pip uninstall -y SAM-2 && \
+rm -f ./sam2/*.so && \
+SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]"
+```
+
+Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.
+
+Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.
+
+### Common Installation Issues
+
+Click each issue for its solutions:
+
+
+
+I got `ImportError: cannot import name '_C' from 'sam2'`
+
+
+
+This is usually because you haven't run the `pip install -e ".[notebooks]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
+
+In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/sam2/issues/77.
+
+
+
+
+I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'`
+
+
+
+This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via
+```bash
+export SAM2_REPO_ROOT=/path/to/sam2 # path to this repo
+export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
+```
+to manually add `sam2_configs` into your Python's `sys.path`.
+
+
+
+
+
+I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints
+
+
+
+This is likely because you have installed a previous version of this repo, which doesn't have the new modules to support the SAM 2.1 checkpoints yet. Please try the following steps:
+
+1. pull the latest code from the `main` branch of this repo
+2. run `pip uninstall -y SAM-2` to uninstall any previous installations
+3. then install the latest repo again using `pip install -e ".[notebooks]"`
+
+In case the steps above still don't resolve the error, please try running in your Python environment the following
+```python
+from sam2.modeling import sam2_base
+
+print(sam2_base.__file__)
+```
+and check whether the content in the printed local path of `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`) to indentify if you're still using a previous installation.
+
+
+
+
+
+My installation failed with `CUDA_HOME environment variable is not set`
+
+
+
+This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
+```
+export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path
+```
+and rerun the installation.
+
+Also, you should make sure
+```
+python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
+```
+print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up.
+
+If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command:
+```
+pip install --no-build-isolation -e .
+```
+
+
+
+
+
+I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
+
+
+
+This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
+
+In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
+
+We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
+
+
+
+
+I got `CUDA error: no kernel image is available for execution on the device`
+
+
+
+A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. in a slurm system).
+
+You can try pulling the latest code from the SAM 2 repo and running the following
+```
+export TORCH_CUDA_ARCH_LIST=9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0`
+```
+to manually specify the CUDA capability in the compilation target that matches your GPU.
+
+
+
+
+I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
+
+
+
+This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line
+```python
+OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+```
+in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
+```python
+OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
+```
+to relax the attention kernel setting and use other kernels than Flash Attention.
+
+
+
+
+I got `Error compiling objects for extension`
+
+
+
+You may see error log of:
+> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.
+
+This is probably because your versions of CUDA and Visual Studio are incompatible. (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion in stackoverflow).
+You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py).
+After adding the argument, `get_extension()` will look like this:
+```python
+def get_extensions():
+ srcs = ["sam2/csrc/connected_components.cu"]
+ compile_args = {
+ "cxx": [],
+ "nvcc": [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ "-allow-unsupported-compiler" # Add this argument
+ ],
+ }
+ ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+ return ext_modules
+```
+
diff --git a/phantom/submodules/sam2/LICENSE b/phantom/submodules/sam2/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/phantom/submodules/sam2/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/phantom/submodules/sam2/LICENSE_cctorch b/phantom/submodules/sam2/LICENSE_cctorch
new file mode 100644
index 0000000000000000000000000000000000000000..23da14a65aad4c5bac18061b80ae6040bb7d2c8c
--- /dev/null
+++ b/phantom/submodules/sam2/LICENSE_cctorch
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/phantom/submodules/sam2/MANIFEST.in b/phantom/submodules/sam2/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..794311fd9854453b134c828c0cb241a7cfdbfc65
--- /dev/null
+++ b/phantom/submodules/sam2/MANIFEST.in
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+recursive-include sam2 *.yaml #include all config files
diff --git a/phantom/submodules/sam2/README.md b/phantom/submodules/sam2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..85a7eb958bced5495ff990c2bcbe7d99662c660f
--- /dev/null
+++ b/phantom/submodules/sam2/README.md
@@ -0,0 +1,224 @@
+# SAM 2: Segment Anything in Images and Videos
+
+**[AI at Meta, FAIR](https://ai.meta.com/research/)**
+
+[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
+
+[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
+
+
+
+**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
+
+
+
+## Latest updates
+
+**12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking**
+
+- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference.
+- We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and adding new objects after tracking starts.
+- See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details.
+
+**09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released**
+
+- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
+ * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
+- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
+- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
+
+## Installation
+
+SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
+
+```bash
+git clone https://github.com/facebookresearch/sam2.git && cd sam2
+
+pip install -e .
+```
+If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
+
+To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
+
+```bash
+pip install -e ".[notebooks]"
+```
+
+Note:
+1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
+2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
+3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).
+
+Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
+
+## Getting Started
+
+### Download Checkpoints
+
+First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
+
+```bash
+cd checkpoints && \
+./download_ckpts.sh && \
+cd ..
+```
+
+or individually from:
+
+- [sam2.1_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)
+- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
+- [sam2.1_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)
+- [sam2.1_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)
+
+(note that these are the improved checkpoints denoted as SAM 2.1; see [Model Description](#model-description) for details.)
+
+Then SAM 2 can be used in a few lines as follows for image and video prediction.
+
+### Image prediction
+
+SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
+
+```python
+import torch
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
+model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+ predictor.set_image()
+ masks, _, _ = predictor.predict()
+```
+
+Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases.
+
+SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images.
+
+### Video prediction
+
+For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
+
+```python
+import torch
+from sam2.build_sam import build_sam2_video_predictor
+
+checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
+model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+predictor = build_sam2_video_predictor(model_cfg, checkpoint)
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+ state = predictor.init_state()
+
+ # add new prompts and instantly get the output on the same frame
+ frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ):
+
+ # propagate the prompts to get masklets throughout the video
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
+ ...
+```
+
+Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
+
+## Load from 🤗 Hugging Face
+
+Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
+
+For image prediction:
+
+```python
+import torch
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+ predictor.set_image()
+ masks, _, _ = predictor.predict()
+```
+
+For video prediction:
+
+```python
+import torch
+from sam2.sam2_video_predictor import SAM2VideoPredictor
+
+predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+ state = predictor.init_state()
+
+ # add new prompts and instantly get the output on the same frame
+ frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ):
+
+ # propagate the prompts to get masklets throughout the video
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
+ ...
+```
+
+## Model Description
+
+### SAM 2.1 checkpoints
+
+The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
+| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
+| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
+| sam2.1_hiera_tiny ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
+| sam2.1_hiera_small ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
+| sam2.1_hiera_base_plus ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
+| sam2.1_hiera_large ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
+
+### SAM 2 checkpoints
+
+The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:
+
+| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
+| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
+| sam2_hiera_tiny ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
+| sam2_hiera_small ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
+| sam2_hiera_base_plus ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
+| sam2_hiera_large ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
+
+Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
+## Segment Anything Video Dataset
+
+See [sav_dataset/README.md](sav_dataset/README.md) for details.
+
+## Training SAM 2
+
+You can train or fine-tune SAM 2 on custom datasets of images, videos, or both. Please check the training [README](training/README.md) on how to get started.
+
+## Web demo for SAM 2
+
+We have released the frontend + backend code for the SAM 2 web demo (a locally deployable version similar to https://sam2.metademolab.com/demo). Please see the web demo [README](demo/README.md) for details.
+
+## License
+
+The SAM 2 model checkpoints, SAM 2 demo code (front-end and back-end), and SAM 2 training code are licensed under [Apache 2.0](./LICENSE), however the [Inter Font](https://github.com/rsms/inter?tab=OFL-1.1-1-ov-file) and [Noto Color Emoji](https://github.com/googlefonts/noto-emoji) used in the SAM 2 demo code are made available under the [SIL Open Font License, version 1.1](https://openfontlicense.org/open-font-license-official-text/).
+
+## Contributing
+
+See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
+
+## Contributors
+
+The SAM 2 project was made possible with the help of many contributors (alphabetical):
+
+Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
+
+Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
+
+## Citing SAM 2
+
+If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
+
+```bibtex
+@article{ravi2024sam2,
+ title={SAM 2: Segment Anything in Images and Videos},
+ author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
+ journal={arXiv preprint arXiv:2408.00714},
+ url={https://arxiv.org/abs/2408.00714},
+ year={2024}
+}
+```
diff --git a/phantom/submodules/sam2/RELEASE_NOTES.md b/phantom/submodules/sam2/RELEASE_NOTES.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee65ae7f4a51f1c7fce81204c5dc94467882d366
--- /dev/null
+++ b/phantom/submodules/sam2/RELEASE_NOTES.md
@@ -0,0 +1,27 @@
+## SAM 2 release notes
+
+### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking
+
+- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
+ * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
+ * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag.
+ * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
+ * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
+- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
+ * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
+ * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
+ * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
+
+### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released
+
+- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
+ * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
+- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
+- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
+
+### 07/29/2024 -- SAM 2 is released
+
+- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
+ * SAM 2 code: https://github.com/facebookresearch/sam2
+ * SAM 2 demo: https://sam2.metademolab.com/
+ * SAM 2 paper: https://arxiv.org/abs/2408.00714
diff --git a/phantom/submodules/sam2/backend.Dockerfile b/phantom/submodules/sam2/backend.Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..54a32967b0e053ae2e7d16c734928636ef46db7b
--- /dev/null
+++ b/phantom/submodules/sam2/backend.Dockerfile
@@ -0,0 +1,64 @@
+ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
+ARG MODEL_SIZE=base_plus
+
+FROM ${BASE_IMAGE}
+
+# Gunicorn environment variables
+ENV GUNICORN_WORKERS=1
+ENV GUNICORN_THREADS=2
+ENV GUNICORN_PORT=5000
+
+# SAM 2 environment variables
+ENV APP_ROOT=/opt/sam2
+ENV PYTHONUNBUFFERED=1
+ENV SAM2_BUILD_CUDA=0
+ENV MODEL_SIZE=${MODEL_SIZE}
+
+# Install system requirements
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ ffmpeg \
+ libavutil-dev \
+ libavcodec-dev \
+ libavformat-dev \
+ libswscale-dev \
+ pkg-config \
+ build-essential \
+ libffi-dev
+
+COPY setup.py .
+COPY README.md .
+
+RUN pip install --upgrade pip setuptools
+RUN pip install -e ".[interactive-demo]"
+
+# https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707
+RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg
+
+# Make app directory. This directory will host all files required for the
+# backend and SAM 2 inference files.
+RUN mkdir ${APP_ROOT}
+
+# Copy backend server files
+COPY demo/backend/server ${APP_ROOT}/server
+
+# Copy SAM 2 inference files
+COPY sam2 ${APP_ROOT}/server/sam2
+
+# Download SAM 2.1 checkpoints
+ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt
+ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt
+ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt
+ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt
+
+WORKDIR ${APP_ROOT}/server
+
+# https://pythonspeed.com/articles/gunicorn-in-docker/
+CMD gunicorn --worker-tmp-dir /dev/shm \
+ --worker-class gthread app:app \
+ --log-level info \
+ --access-logfile /dev/stdout \
+ --log-file /dev/stderr \
+ --workers ${GUNICORN_WORKERS} \
+ --threads ${GUNICORN_THREADS} \
+ --bind 0.0.0.0:${GUNICORN_PORT} \
+ --timeout 60
diff --git a/phantom/submodules/sam2/checkpoints/download_ckpts.sh b/phantom/submodules/sam2/checkpoints/download_ckpts.sh
new file mode 100755
index 0000000000000000000000000000000000000000..eedee8eee153f17c6db3b92de5492fa0a11ec3b7
--- /dev/null
+++ b/phantom/submodules/sam2/checkpoints/download_ckpts.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Use either wget or curl to download the checkpoints
+if command -v wget &> /dev/null; then
+ CMD="wget"
+elif command -v curl &> /dev/null; then
+ CMD="curl -L -O"
+else
+ echo "Please install wget or curl to download the checkpoints."
+ exit 1
+fi
+
+# Define the URLs for SAM 2 checkpoints
+# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
+# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
+# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
+# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
+# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
+
+# Download each of the four checkpoints using wget
+# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
+# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
+
+# echo "Downloading sam2_hiera_small.pt checkpoint..."
+# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
+
+# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
+# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
+
+# echo "Downloading sam2_hiera_large.pt checkpoint..."
+# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
+
+# Define the URLs for SAM 2.1 checkpoints
+SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
+sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
+sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
+sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
+sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
+
+# SAM 2.1 checkpoints
+echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
+$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
+
+echo "Downloading sam2.1_hiera_small.pt checkpoint..."
+$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
+
+echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
+$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
+
+echo "Downloading sam2.1_hiera_large.pt checkpoint..."
+$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
+
+echo "All checkpoints are downloaded successfully."
diff --git a/phantom/submodules/sam2/docker-compose.yaml b/phantom/submodules/sam2/docker-compose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7a5395a585daa7d5a6e0e97d3a30b48f225fb2cf
--- /dev/null
+++ b/phantom/submodules/sam2/docker-compose.yaml
@@ -0,0 +1,42 @@
+services:
+ frontend:
+ image: sam2/frontend
+ build:
+ context: ./demo/frontend
+ dockerfile: frontend.Dockerfile
+ ports:
+ - 7262:80
+
+ backend:
+ image: sam2/backend
+ build:
+ context: .
+ dockerfile: backend.Dockerfile
+ ports:
+ - 7263:5000
+ volumes:
+ - ./demo/data/:/data/:rw
+ environment:
+ - SERVER_ENVIRONMENT=DEV
+ - GUNICORN_WORKERS=1
+ # Inference API needs to have at least 2 threads to handle an incoming
+ # parallel cancel propagation request
+ - GUNICORN_THREADS=2
+ - GUNICORN_PORT=5000
+ - API_URL=http://localhost:7263
+ - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4
+ # # ffmpeg/video encode settings
+ - FFMPEG_NUM_THREADS=1
+ - VIDEO_ENCODE_CODEC=libx264
+ - VIDEO_ENCODE_CRF=23
+ - VIDEO_ENCODE_FPS=24
+ - VIDEO_ENCODE_MAX_WIDTH=1280
+ - VIDEO_ENCODE_MAX_HEIGHT=720
+ - VIDEO_ENCODE_VERBOSE=False
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/phantom/submodules/sam2/pyproject.toml b/phantom/submodules/sam2/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..f84317dbbfa6ba4f2d972cab2e2e0d0bdf07f003
--- /dev/null
+++ b/phantom/submodules/sam2/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+ "setuptools>=61.0",
+ "torch>=2.5.1",
+ ]
+build-backend = "setuptools.build_meta"
diff --git a/phantom/submodules/sam2/sam2/__init__.py b/phantom/submodules/sam2/sam2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0712dd03cb280ab94ba04f8a32aa8ddc8aa3db4a
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from hydra import initialize_config_module
+from hydra.core.global_hydra import GlobalHydra
+
+if not GlobalHydra.instance().is_initialized():
+ initialize_config_module("sam2", version_base="1.2")
diff --git a/phantom/submodules/sam2/sam2/automatic_mask_generator.py b/phantom/submodules/sam2/sam2/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..065e469e27c2d3af40d51d072031e828692c799b
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/automatic_mask_generator.py
@@ -0,0 +1,454 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.utils.amg import (
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ MaskData,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SAM2AutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: SAM2Base,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.8,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ mask_threshold: float = 0.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ use_m2m: bool = False,
+ multimask_output: bool = True,
+ **kwargs,
+ ) -> None:
+ """
+ Using a SAM 2 model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM 2 with a HieraL backbone.
+
+ Arguments:
+ model (Sam): The SAM 2 model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+ calculated the stability score.
+ mask_threshold (float): Threshold for binarizing the mask logits
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
+ multimask_output (bool): Whether to output multimask at each point of the grid.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ try:
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+ except ImportError as e:
+ print("Please install pycocotools")
+ raise e
+
+ self.predictor = SAM2ImagePredictor(
+ model,
+ max_hole_area=min_mask_region_area,
+ max_sprinkle_area=min_mask_region_area,
+ )
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.mask_threshold = mask_threshold
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+ self.use_m2m = use_m2m
+ self.multimask_output = multimask_output
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+ **kwargs: Additional arguments to pass to the model constructor.
+
+ Returns:
+ (SAM2AutomaticMaskGenerator): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_hf
+
+ sam_model = build_sam2_hf(model_id, **kwargs)
+ return cls(sam_model, **kwargs)
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [
+ coco_encode_rle(rle) for rle in mask_data["rles"]
+ ]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(
+ points, cropped_im_size, crop_box, orig_size, normalize=True
+ )
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_predictor()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ normalize=False,
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ points = torch.as_tensor(
+ points, dtype=torch.float32, device=self.predictor.device
+ )
+ in_points = self.predictor._transforms.transform_coords(
+ points, normalize=normalize, orig_hw=im_size
+ )
+ in_labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
+ masks, iou_preds, low_res_masks = self.predictor._predict(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=self.multimask_output,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=points.repeat_interleave(masks.shape[1], dim=0),
+ low_res_masks=low_res_masks.flatten(0, 1),
+ )
+ del masks
+
+ if not self.use_m2m:
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate and filter by stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+ else:
+ # One step refinement using previous mask predictions
+ in_points = self.predictor._transforms.transform_coords(
+ data["points"], normalize=normalize, orig_hw=im_size
+ )
+ labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
+ masks, ious = self.refine_with_m2m(
+ in_points, labels, data["low_res_masks"], self.points_per_batch
+ )
+ data["masks"] = masks.squeeze(1)
+ data["iou_preds"] = ious.squeeze(1)
+
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
+ )
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros_like(boxes[:, 0]), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
+
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
+ new_masks = []
+ new_iou_preds = []
+
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
+ points_per_batch, points, point_labels, low_res_masks
+ ):
+ best_masks, best_iou_preds, _ = self.predictor._predict(
+ cur_points[:, None, :],
+ cur_point_labels[:, None],
+ mask_input=low_res_mask[:, None, :],
+ multimask_output=False,
+ return_logits=True,
+ )
+ new_masks.append(best_masks)
+ new_iou_preds.append(best_iou_preds)
+ masks = torch.cat(new_masks, dim=0)
+ return masks, torch.cat(new_iou_preds, dim=0)
diff --git a/phantom/submodules/sam2/sam2/benchmark.py b/phantom/submodules/sam2/sam2/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6519534c8619e04b9a632859a5128ad2cee34c13
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/benchmark.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from sam2.build_sam import build_sam2_video_predictor
+
+# Only cuda supported
+assert torch.cuda.is_available()
+device = torch.device("cuda")
+
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+# Config and checkpoint
+sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
+model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+
+# Build video predictor with vos_optimized=True setting
+predictor = build_sam2_video_predictor(
+ model_cfg, sam2_checkpoint, device=device, vos_optimized=True
+)
+
+
+# Initialize with video
+video_dir = "notebooks/videos/bedroom"
+# scan all the JPEG frame names in this directory
+frame_names = [
+ p
+ for p in os.listdir(video_dir)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+]
+frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+inference_state = predictor.init_state(video_path=video_dir)
+
+
+# Number of runs, warmup etc
+warm_up, runs = 5, 25
+verbose = True
+num_frames = len(frame_names)
+total, count = 0, 0
+torch.cuda.empty_cache()
+
+# We will select an object with a click.
+# See video_predictor_example.ipynb for more detailed explanation
+ann_frame_idx, ann_obj_id = 0, 1
+# Add a positive click at (x, y) = (210, 350)
+# For labels, `1` means positive click
+points = np.array([[210, 350]], dtype=np.float32)
+labels = np.array([1], np.int32)
+
+_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+ inference_state=inference_state,
+ frame_idx=ann_frame_idx,
+ obj_id=ann_obj_id,
+ points=points,
+ labels=labels,
+)
+
+# Warmup and then average FPS over several runs
+with torch.autocast("cuda", torch.bfloat16):
+ with torch.inference_mode():
+ for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
+ start = time.time()
+ # Start tracking
+ for (
+ out_frame_idx,
+ out_obj_ids,
+ out_mask_logits,
+ ) in predictor.propagate_in_video(inference_state):
+ pass
+
+ end = time.time()
+ total += end - start
+ count += 1
+ if i == warm_up - 1:
+ print("Warmup FPS: ", count * num_frames / total)
+ total = 0
+ count = 0
+
+print("FPS: ", count * num_frames / total)
diff --git a/phantom/submodules/sam2/sam2/build_sam.py b/phantom/submodules/sam2/sam2/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3bef1e566d86c3ba0fd75f425530bc6505e9bf
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/build_sam.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import torch
+from hydra import compose
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+
+import sam2
+
+# Check if the user is running Python from the parent directory of the sam2 repo
+# (i.e. the directory where this repo is cloned into) -- this is not supported since
+# it could shadow the sam2 package and cause issues.
+if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
+ # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
+ # This typically happens because the user is running Python from the parent directory
+ # that contains the sam2 repo they cloned.
+ raise RuntimeError(
+ "You're likely running Python from the parent directory of the sam2 repository "
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
+ "This is not supported since the `sam2` Python package could be shadowed by the "
+ "repository name (the repository is also named `sam2` and contains the Python package "
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
+ )
+
+
+HF_MODEL_ID_TO_FILENAMES = {
+ "facebook/sam2-hiera-tiny": (
+ "configs/sam2/sam2_hiera_t.yaml",
+ "sam2_hiera_tiny.pt",
+ ),
+ "facebook/sam2-hiera-small": (
+ "configs/sam2/sam2_hiera_s.yaml",
+ "sam2_hiera_small.pt",
+ ),
+ "facebook/sam2-hiera-base-plus": (
+ "configs/sam2/sam2_hiera_b+.yaml",
+ "sam2_hiera_base_plus.pt",
+ ),
+ "facebook/sam2-hiera-large": (
+ "configs/sam2/sam2_hiera_l.yaml",
+ "sam2_hiera_large.pt",
+ ),
+ "facebook/sam2.1-hiera-tiny": (
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
+ "sam2.1_hiera_tiny.pt",
+ ),
+ "facebook/sam2.1-hiera-small": (
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
+ "sam2.1_hiera_small.pt",
+ ),
+ "facebook/sam2.1-hiera-base-plus": (
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
+ "sam2.1_hiera_base_plus.pt",
+ ),
+ "facebook/sam2.1-hiera-large": (
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
+ "sam2.1_hiera_large.pt",
+ ),
+}
+
+
+def build_sam2(
+ config_file,
+ ckpt_path=None,
+ device="cuda",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+ **kwargs,
+):
+
+ if apply_postprocessing:
+ hydra_overrides_extra = hydra_overrides_extra.copy()
+ hydra_overrides_extra += [
+ # dynamically fall back to multi-mask if the single mask is not stable
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+ ]
+ # Read config and init model
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
+ OmegaConf.resolve(cfg)
+ model = instantiate(cfg.model, _recursive_=True)
+ _load_checkpoint(model, ckpt_path)
+ model = model.to(device)
+ if mode == "eval":
+ model.eval()
+ return model
+
+
+def build_sam2_video_predictor(
+ config_file,
+ ckpt_path=None,
+ device="cuda",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+ vos_optimized=False,
+ **kwargs,
+):
+ hydra_overrides = [
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
+ ]
+ if vos_optimized:
+ hydra_overrides = [
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+ "++model.compile_image_encoder=True", # Let sam2_base handle this
+ ]
+
+ if apply_postprocessing:
+ hydra_overrides_extra = hydra_overrides_extra.copy()
+ hydra_overrides_extra += [
+ # dynamically fall back to multi-mask if the single mask is not stable
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
+ "++model.fill_hole_area=8",
+ ]
+ hydra_overrides.extend(hydra_overrides_extra)
+
+ # Read config and init model
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
+ OmegaConf.resolve(cfg)
+ model = instantiate(cfg.model, _recursive_=True)
+ _load_checkpoint(model, ckpt_path)
+ model = model.to(device)
+ if mode == "eval":
+ model.eval()
+ return model
+
+
+def _hf_download(model_id):
+ from huggingface_hub import hf_hub_download
+
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
+ return config_name, ckpt_path
+
+
+def build_sam2_hf(model_id, **kwargs):
+ config_name, ckpt_path = _hf_download(model_id)
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
+
+
+def build_sam2_video_predictor_hf(model_id, **kwargs):
+ config_name, ckpt_path = _hf_download(model_id)
+ return build_sam2_video_predictor(
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
+ )
+
+
+def _load_checkpoint(model, ckpt_path):
+ if ckpt_path is not None:
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
+ if missing_keys:
+ logging.error(missing_keys)
+ raise RuntimeError()
+ if unexpected_keys:
+ logging.error(unexpected_keys)
+ raise RuntimeError()
+ logging.info("Loaded checkpoint sucessfully")
diff --git a/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d7172f9b0b663aaaace97fed7e2a08db75150461
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..23073ea7a95901be656b3c6d1a66ce8736ab7ad3
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
@@ -0,0 +1,120 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fd8d40465b18b3de39b0a565aca712306306c4ed
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
@@ -0,0 +1,119 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 11, 2]
+ global_att_blocks: [7, 10, 13]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e762aec932f26436d13798f3feb3ec82c360a943
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
@@ -0,0 +1,121 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 7, 2]
+ global_att_blocks: [5, 7, 9]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ # SAM decoder
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # HieraT does not currently support compilation, should always be set to False
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/phantom/submodules/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b6faa79f47ee576faf007bffd23fb6649bd881d
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
@@ -0,0 +1,339 @@
+# @package _global_
+
+scratch:
+ resolution: 1024
+ train_batch_size: 1
+ num_train_workers: 10
+ num_frames: 8
+ max_num_objects: 3
+ base_lr: 5.0e-6
+ vision_lr: 3.0e-06
+ phases_per_epoch: 1
+ num_epochs: 40
+
+dataset:
+ # PATHS to Dataset
+ img_folder: null # PATH to MOSE JPEGImages folder
+ gt_folder: null # PATH to MOSE Annotations folder
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
+ multiplier: 2
+
+# Video transforms
+vos:
+ train_transforms:
+ - _target_: training.dataset.transforms.ComposeAPI
+ transforms:
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomAffine
+ degrees: 25
+ shear: 20
+ image_interpolation: bilinear
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomResizeAPI
+ sizes: ${scratch.resolution}
+ square: true
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: True
+ brightness: 0.1
+ contrast: 0.03
+ saturation: 0.03
+ hue: null
+ - _target_: training.dataset.transforms.RandomGrayscale
+ p: 0.05
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: False
+ brightness: 0.1
+ contrast: 0.05
+ saturation: 0.05
+ hue: null
+ - _target_: training.dataset.transforms.ToTensorAPI
+ - _target_: training.dataset.transforms.NormalizeAPI
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+
+trainer:
+ _target_: training.trainer.Trainer
+ mode: train_only
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
+ accelerator: cuda
+ seed_value: 123
+
+ model:
+ _target_: training.model.sam2.SAM2Train
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ drop_path_rate: 0.1
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: ${scratch.resolution}
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # compile_image_encoder: False
+
+ ####### Training specific params #######
+ # box/point input and corrections
+ prob_to_use_pt_input_for_train: 0.5
+ prob_to_use_pt_input_for_eval: 0.0
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
+ prob_to_use_box_input_for_eval: 0.0
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
+ # maximum 2 initial conditioning frames
+ num_init_cond_frames_for_train: 2
+ rand_init_cond_frames_for_train: True # random 1~2
+ num_correction_pt_per_frame: 7
+ use_act_ckpt_iterative_pt_sampling: false
+
+
+
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
+ forward_backbone_per_frame_for_eval: True
+
+
+ data:
+ train:
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
+ phases_per_epoch: ${scratch.phases_per_epoch}
+ batch_sizes:
+ - ${scratch.train_batch_size}
+
+ datasets:
+ - _target_: training.dataset.utils.RepeatFactorWrapper
+ dataset:
+ _target_: training.dataset.utils.ConcatDataset
+ datasets:
+ - _target_: training.dataset.vos_dataset.VOSDataset
+ transforms: ${vos.train_transforms}
+ training: true
+ video_dataset:
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
+ img_folder: ${dataset.img_folder}
+ gt_folder: ${dataset.gt_folder}
+ file_list_txt: ${dataset.file_list_txt}
+ sampler:
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
+ num_frames: ${scratch.num_frames}
+ max_num_objects: ${scratch.max_num_objects}
+ multiplier: ${dataset.multiplier}
+ shuffle: True
+ num_workers: ${scratch.num_train_workers}
+ pin_memory: True
+ drop_last: True
+ collate_fn:
+ _target_: training.utils.data_utils.collate_fn
+ _partial_: true
+ dict_key: all
+
+ optim:
+ amp:
+ enabled: True
+ amp_dtype: bfloat16
+
+ optimizer:
+ _target_: torch.optim.AdamW
+
+ gradient_clip:
+ _target_: training.optimizer.GradientClipper
+ max_norm: 0.1
+ norm_type: 2
+
+ param_group_modifiers:
+ - _target_: training.optimizer.layer_decay_param_modifier
+ _partial_: True
+ layer_decay_value: 0.9
+ apply_to: 'image_encoder.trunk'
+ overrides:
+ - pattern: '*pos_embed*'
+ value: 1.0
+
+ options:
+ lr:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.base_lr}
+ end_value: ${divide:${scratch.base_lr},10}
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.vision_lr}
+ end_value: ${divide:${scratch.vision_lr},10}
+ param_names:
+ - 'image_encoder.*'
+ weight_decay:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.1
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.0
+ param_names:
+ - '*bias*'
+ module_cls_names: ['torch.nn.LayerNorm']
+
+ loss:
+ all:
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
+ weight_dict:
+ loss_mask: 20
+ loss_dice: 1
+ loss_iou: 1
+ loss_class: 1
+ supervise_all_iou: true
+ iou_use_l1_loss: true
+ pred_obj_scores: true
+ focal_gamma_obj_score: 0.0
+ focal_alpha_obj_score: -1.0
+
+ distributed:
+ backend: nccl
+ find_unused_parameters: True
+
+ logging:
+ tensorboard_writer:
+ _target_: training.utils.logger.make_tensorboard_logger
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
+ flush_secs: 120
+ should_log: True
+ log_dir: ${launcher.experiment_log_dir}/logs
+ log_freq: 10
+
+ # initialize from a SAM 2 checkpoint
+ checkpoint:
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
+ save_freq: 0 # 0 only last checkpoint is saved.
+ model_weight_initializer:
+ _partial_: True
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
+ strict: True
+ ignore_unexpected_keys: null
+ ignore_missing_keys: null
+
+ state_dict:
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
+ ckpt_state_dict_keys: ['model']
+
+launcher:
+ num_nodes: 1
+ gpus_per_node: 8
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
+
+# SLURM args if running on a cluster
+submitit:
+ partition: null
+ account: null
+ qos: null
+ cpus_per_task: 10
+ use_cluster: false
+ timeout_hour: 24
+ name: null
+ port_range: [10000, 65000]
+
diff --git a/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f435af02fc88e2d3b7bff06f8cf8013cc079c24
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_b+.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_l.yaml b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1092802b1d24be6fedf78939f45b0d021d4ec560
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_s.yaml b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..174e414f1467d80e94a34e9525dc373058f8caaa
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_s.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 11, 2]
+ global_att_blocks: [7, 10, 13]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_t.yaml b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_t.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..121447aabd5318fac20efc2bc00d7c406ca26f01
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/configs/sam2/sam2_hiera_t.yaml
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 7, 2]
+ global_att_blocks: [5, 7, 9]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [64, 64]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ # SAM decoder
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # HieraT does not currently support compilation, should always be set to False
+ compile_image_encoder: False
diff --git a/phantom/submodules/sam2/sam2/csrc/connected_components.cu b/phantom/submodules/sam2/sam2/csrc/connected_components.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ced21eb32eaaadb818d441c1322b99d1bf068f45
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/csrc/connected_components.cu
@@ -0,0 +1,289 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// adapted from https://github.com/zsef123/Connected_components_PyTorch
+// with license found in the LICENSE_cctorch file in the root directory.
+#include
+#include
+#include
+#include
+#include
+#include
+
+// 2d
+#define BLOCK_ROWS 16
+#define BLOCK_COLS 16
+
+namespace cc2d {
+
+template
+__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
+ return (bitmap >> pos) & 1;
+}
+
+__device__ int32_t find(const int32_t* s_buf, int32_t n) {
+ while (s_buf[n] != n)
+ n = s_buf[n];
+ return n;
+}
+
+__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
+ const int32_t id = n;
+ while (s_buf[n] != n) {
+ n = s_buf[n];
+ s_buf[id] = n;
+ }
+ return n;
+}
+
+__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
+ bool done;
+ do {
+ a = find(s_buf, a);
+ b = find(s_buf, b);
+
+ if (a < b) {
+ int32_t old = atomicMin(s_buf + b, a);
+ done = (old == b);
+ b = old;
+ } else if (b < a) {
+ int32_t old = atomicMin(s_buf + a, b);
+ done = (old == a);
+ a = old;
+ } else
+ done = true;
+
+ } while (!done);
+}
+
+__global__ void
+init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row < H && col < W)
+ label[idx] = idx;
+}
+
+__global__ void
+merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ uint32_t P = 0;
+
+ if (img[idx])
+ P |= 0x777;
+ if (row + 1 < H && img[idx + W])
+ P |= 0x777 << 4;
+ if (col + 1 < W && img[idx + 1])
+ P |= 0x777 << 1;
+
+ if (col == 0)
+ P &= 0xEEEE;
+ if (col + 1 >= W)
+ P &= 0x3333;
+ else if (col + 2 >= W)
+ P &= 0x7777;
+
+ if (row == 0)
+ P &= 0xFFF0;
+ if (row + 1 >= H)
+ P &= 0xFF;
+
+ if (P > 0) {
+ // If need check about top-left pixel(if flag the first bit) and hit the
+ // top-left pixel
+ if (hasBit(P, 0) && img[idx - W - 1]) {
+ union_(label, idx, idx - 2 * W - 2); // top left block
+ }
+
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
+ union_(label, idx, idx - 2 * W); // top bottom block
+
+ if (hasBit(P, 3) && img[idx + 2 - W])
+ union_(label, idx, idx - 2 * W + 2); // top right block
+
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
+ union_(label, idx, idx - 2); // just left block
+ }
+}
+
+__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row < H && col < W)
+ find_n_compress(label, idx);
+}
+
+__global__ void final_labeling(
+ const uint8_t* img,
+ int32_t* label,
+ const int32_t W,
+ const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ int32_t y = label[idx] + 1;
+
+ if (img[idx])
+ label[idx] = y;
+ else
+ label[idx] = 0;
+
+ if (col + 1 < W) {
+ if (img[idx + 1])
+ label[idx + 1] = y;
+ else
+ label[idx + 1] = 0;
+
+ if (row + 1 < H) {
+ if (img[idx + W + 1])
+ label[idx + W + 1] = y;
+ else
+ label[idx + W + 1] = 0;
+ }
+ }
+
+ if (row + 1 < H) {
+ if (img[idx + W])
+ label[idx + W] = y;
+ else
+ label[idx + W] = 0;
+ }
+}
+
+__global__ void init_counting(
+ const int32_t* label,
+ int32_t* count_init,
+ const int32_t W,
+ const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ int32_t y = label[idx];
+ if (y > 0) {
+ int32_t count_idx = y - 1;
+ atomicAdd(count_init + count_idx, 1);
+ }
+}
+
+__global__ void final_counting(
+ const int32_t* label,
+ const int32_t* count_init,
+ int32_t* count_final,
+ const int32_t W,
+ const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ int32_t y = label[idx];
+ if (y > 0) {
+ int32_t count_idx = y - 1;
+ count_final[idx] = count_init[count_idx];
+ } else {
+ count_final[idx] = 0;
+ }
+}
+
+} // namespace cc2d
+
+std::vector get_connected_componnets(
+ const torch::Tensor& inputs) {
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
+ AT_ASSERTM(
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
+
+ const uint32_t N = inputs.size(0);
+ const uint32_t C = inputs.size(1);
+ const uint32_t H = inputs.size(2);
+ const uint32_t W = inputs.size(3);
+
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
+
+ // label must be uint32_t
+ auto label_options =
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
+
+ dim3 grid = dim3(
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
+ dim3 grid_count =
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ for (int n = 0; n < N; n++) {
+ uint32_t offset = n * H * W;
+
+ cc2d::init_labeling<<>>(
+ labels.data_ptr() + offset, W, H);
+ cc2d::merge<<>>(
+ inputs.data_ptr() + offset,
+ labels.data_ptr() + offset,
+ W,
+ H);
+ cc2d::compression<<>>(
+ labels.data_ptr() + offset, W, H);
+ cc2d::final_labeling<<>>(
+ inputs.data_ptr() + offset,
+ labels.data_ptr() + offset,
+ W,
+ H);
+
+ // get the counting of each pixel
+ cc2d::init_counting<<>>(
+ labels.data_ptr() + offset,
+ counts_init.data_ptr() + offset,
+ W,
+ H);
+ cc2d::final_counting<<>>(
+ labels.data_ptr() + offset,
+ counts_init.data_ptr() + offset,
+ counts_final.data_ptr() + offset,
+ W,
+ H);
+ }
+
+ // returned values are [labels, counts]
+ std::vector outputs;
+ outputs.push_back(labels);
+ outputs.push_back(counts_final);
+ return outputs;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def(
+ "get_connected_componnets",
+ &get_connected_componnets,
+ "get_connected_componnets");
+}
diff --git a/phantom/submodules/sam2/sam2/modeling/__init__.py b/phantom/submodules/sam2/sam2/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/sam2/modeling/backbones/__init__.py b/phantom/submodules/sam2/sam2/modeling/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/backbones/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/sam2/modeling/backbones/hieradet.py b/phantom/submodules/sam2/sam2/modeling/backbones/hieradet.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ac77b61d8e1345a301686d39ef2ab6e4b035fb
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/backbones/hieradet.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from functools import partial
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from iopath.common.file_io import g_pathmgr
+
+from sam2.modeling.backbones.utils import (
+ PatchEmbed,
+ window_partition,
+ window_unpartition,
+)
+
+from sam2.modeling.sam2_utils import DropPath, MLP
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+ if pool is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = pool(x)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ if norm:
+ x = norm(x)
+
+ return x
+
+
+class MultiScaleAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ q_pool: nn.Module = None,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.num_heads = num_heads
+ self.q_pool = q_pool
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ q, k, v = torch.unbind(qkv, 2)
+
+ # Q pooling (for downsample at stage changes)
+ if self.q_pool:
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+ H, W = q.shape[1:3] # downsampled shape
+ q = q.reshape(B, H * W, self.num_heads, -1)
+
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+ x = F.scaled_dot_product_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ )
+ # Transpose back
+ x = x.transpose(1, 2)
+ x = x.reshape(B, H, W, -1)
+
+ x = self.proj(x)
+
+ return x
+
+
+class MultiScaleBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ drop_path: float = 0.0,
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
+ q_stride: Tuple[int, int] = None,
+ act_layer: nn.Module = nn.GELU,
+ window_size: int = 0,
+ ):
+ super().__init__()
+
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.norm1 = norm_layer(dim)
+
+ self.window_size = window_size
+
+ self.pool, self.q_stride = None, q_stride
+ if self.q_stride:
+ self.pool = nn.MaxPool2d(
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
+ )
+
+ self.attn = MultiScaleAttention(
+ dim,
+ dim_out,
+ num_heads=num_heads,
+ q_pool=self.pool,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim_out)
+ self.mlp = MLP(
+ dim_out,
+ int(dim_out * mlp_ratio),
+ dim_out,
+ num_layers=2,
+ activation=act_layer,
+ )
+
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x # B, H, W, C
+ x = self.norm1(x)
+
+ # Skip connection
+ if self.dim != self.dim_out:
+ shortcut = do_pool(self.proj(x), self.pool)
+
+ # Window partition
+ window_size = self.window_size
+ if window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ x = self.attn(x)
+ if self.q_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.q_stride[0]
+ H, W = shortcut.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+ x = shortcut + self.drop_path(x)
+ # MLP
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Hiera(nn.Module):
+ """
+ Reference: https://arxiv.org/abs/2306.00989
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
+ head_mul: float = 2.0, # head_mul factor at stage shift
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ # window size per stage, when not using global att.
+ window_spec: Tuple[int, ...] = (
+ 8,
+ 4,
+ 14,
+ 7,
+ ),
+ # global attn in these blocks
+ global_att_blocks: Tuple[int, ...] = (
+ 12,
+ 16,
+ 20,
+ ),
+ weights_path=None,
+ return_interm_layers=True, # return feats from every stage
+ ):
+ super().__init__()
+
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec
+
+ depth = sum(stages)
+ self.q_stride = q_stride
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers
+
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ )
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks
+
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
+ )
+ self.pos_embed_window = nn.Parameter(
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+
+ cur_stage = 1
+ self.blocks = nn.ModuleList()
+
+ for i in range(depth):
+ dim_out = embed_dim
+ # lags by a block, so first block of
+ # next stage uses an initial window size
+ # of previous stage and final window size of current stage
+ window_size = self.window_spec[cur_stage - 1]
+
+ if self.global_att_blocks is not None:
+ window_size = 0 if i in self.global_att_blocks else window_size
+
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1
+
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )
+
+ embed_dim = dim_out
+ self.blocks.append(block)
+
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+
+ if weights_path is not None:
+ with g_pathmgr.open(weights_path, "rb") as f:
+ chkpt = torch.load(f, map_location="cpu")
+ logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
+
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+ pos_embed = pos_embed + window_embed.tile(
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (
+ i in self.stage_ends and self.return_interm_layers
+ ):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+
+ return outputs
+
+ def get_layer_id(self, layer_name):
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ num_layers = self.get_num_layers()
+
+ if layer_name.find("rel_pos") != -1:
+ return num_layers + 1
+ elif layer_name.find("pos_embed") != -1:
+ return 0
+ elif layer_name.find("patch_embed") != -1:
+ return 0
+ elif layer_name.find("blocks") != -1:
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
+ else:
+ return num_layers + 1
+
+ def get_num_layers(self) -> int:
+ return len(self.blocks)
diff --git a/phantom/submodules/sam2/sam2/modeling/backbones/image_encoder.py b/phantom/submodules/sam2/sam2/modeling/backbones/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert (
+ self.trunk.channel_list == self.neck.backbone_channel_list
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+ def forward(self, sample: torch.Tensor):
+ # Forward through backbone
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+ src = features[-1]
+ output = {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+ return output
+
+
+class FpnNeck(nn.Module):
+ """
+ A modified variant of Feature Pyramid Network (FPN) neck
+ (we remove output conv and also do bicubic interpolation similar to ViT
+ pos embed interpolation)
+ """
+
+ def __init__(
+ self,
+ position_encoding: nn.Module,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ """Initialize the neck
+ :param trunk: the backbone
+ :param position_encoding: the positional encoding to use
+ :param d_model: the dimension of the model
+ :param neck_norm: the normalization to use
+ """
+ super().__init__()
+ self.position_encoding = position_encoding
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ self.d_model = d_model
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )
+
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in ["sum", "avg"]
+ self.fuse_type = fuse_type
+
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+ def forward(self, xs: List[torch.Tensor]):
+
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(
+ None if self.fpn_interp_model == "nearest" else False
+ ),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+ return out, pos
diff --git a/phantom/submodules/sam2/sam2/modeling/backbones/utils.py b/phantom/submodules/sam2/sam2/modeling/backbones/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..930b1b7622e7b0e7270120dcafccc242ef0f4f28
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/backbones/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.reshape(
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :]
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, ...] = (7, 7),
+ stride: Tuple[int, ...] = (4, 4),
+ padding: Tuple[int, ...] = (3, 3),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ):
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
diff --git a/phantom/submodules/sam2/sam2/modeling/memory_attention.py b/phantom/submodules/sam2/sam2/modeling/memory_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b07f9d87e3d8194ca5e11fc20f01604d591a59d
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/memory_attention.py
@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from sam2.modeling.sam.transformer import RoPEAttention
+
+from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+
+
+class MemoryAttentionLayer(nn.Module):
+
+ def __init__(
+ self,
+ activation: str,
+ cross_attention: nn.Module,
+ d_model: int,
+ dim_feedforward: int,
+ dropout: float,
+ pos_enc_at_attn: bool,
+ pos_enc_at_cross_attn_keys: bool,
+ pos_enc_at_cross_attn_queries: bool,
+ self_attention: nn.Module,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.dim_feedforward = dim_feedforward
+ self.dropout_value = dropout
+ self.self_attn = self_attention
+ self.cross_attn_image = cross_attention
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation_str = activation
+ self.activation = get_activation_fn(activation)
+
+ # Where to add pos enc
+ self.pos_enc_at_attn = pos_enc_at_attn
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+ def _forward_sa(self, tgt, query_pos):
+ # Self-Attention
+ tgt2 = self.norm1(tgt)
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+ tgt2 = self.self_attn(q, k, v=tgt2)
+ tgt = tgt + self.dropout1(tgt2)
+ return tgt
+
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+ kwds = {}
+ if num_k_exclude_rope > 0:
+ assert isinstance(self.cross_attn_image, RoPEAttention)
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+ # Cross-Attention
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.cross_attn_image(
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+ v=memory,
+ **kwds,
+ )
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+
+ # Self-Attn, Cross-Attn
+ tgt = self._forward_sa(tgt, query_pos)
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+ # MLP
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+
+class MemoryAttention(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ pos_enc_at_input: bool,
+ layer: nn.Module,
+ num_layers: int,
+ batch_first: bool = True, # Do layers expect batch first input?
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.layers = get_clones(layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = nn.LayerNorm(d_model)
+ self.pos_enc_at_input = pos_enc_at_input
+ self.batch_first = batch_first
+
+ def forward(
+ self,
+ curr: torch.Tensor, # self-attention inputs
+ memory: torch.Tensor, # cross-attention inputs
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
+ ):
+ if isinstance(curr, list):
+ assert isinstance(curr_pos, list)
+ assert len(curr) == len(curr_pos) == 1
+ curr, curr_pos = (
+ curr[0],
+ curr_pos[0],
+ )
+
+ assert (
+ curr.shape[1] == memory.shape[1]
+ ), "Batch size must be the same for curr and memory"
+
+ output = curr
+ if self.pos_enc_at_input and curr_pos is not None:
+ output = output + 0.1 * curr_pos
+
+ if self.batch_first:
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+ memory = memory.transpose(0, 1)
+ memory_pos = memory_pos.transpose(0, 1)
+
+ for layer in self.layers:
+ kwds = {}
+ if isinstance(layer.cross_attn_image, RoPEAttention):
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+ output = layer(
+ tgt=output,
+ memory=memory,
+ pos=memory_pos,
+ query_pos=curr_pos,
+ **kwds,
+ )
+ normed_output = self.norm(output)
+
+ if self.batch_first:
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+
+ return normed_output
diff --git a/phantom/submodules/sam2/sam2/modeling/memory_encoder.py b/phantom/submodules/sam2/sam2/modeling/memory_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60202dfaba87232c3870fb2101b5322a119d985
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/memory_encoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
+
+
+class MaskDownSampler(nn.Module):
+ """
+ Progressively downsample a mask by total_stride, each time by stride.
+ Note that LayerNorm is applied per *token*, like in ViT.
+
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+ In the end, we linearly project to embed_dim channels.
+ """
+
+ def __init__(
+ self,
+ embed_dim=256,
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ total_stride=16,
+ activation=nn.GELU,
+ ):
+ super().__init__()
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
+ assert stride**num_layers == total_stride
+ self.encoder = nn.Sequential()
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (stride**2)
+ self.encoder.append(
+ nn.Conv2d(
+ mask_in_chans,
+ mask_out_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ )
+ self.encoder.append(LayerNorm2d(mask_out_chans))
+ self.encoder.append(activation())
+ mask_in_chans = mask_out_chans
+
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+ def forward(self, x):
+ return self.encoder(x)
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ padding=3,
+ drop_path=0.0,
+ layer_scale_init_value=1e-6,
+ use_dwconv=True,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim if use_dwconv else 1,
+ ) # depthwise conv
+ self.norm = LayerNorm2d(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, 4 * dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class Fuser(nn.Module):
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ super().__init__()
+ self.proj = nn.Identity()
+ self.layers = get_clones(layer, num_layers)
+
+ if input_projection:
+ assert dim is not None
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+ def forward(self, x):
+ # normally x: (N, C, H, W)
+ x = self.proj(x)
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class MemoryEncoder(nn.Module):
+ def __init__(
+ self,
+ out_dim,
+ mask_downsampler,
+ fuser,
+ position_encoding,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ super().__init__()
+
+ self.mask_downsampler = mask_downsampler
+
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = fuser
+ self.position_encoding = position_encoding
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ## Process masks
+ # sigmoid, so that less domain shift from gt masks which are bool
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)
+
+ ## Fuse pix_feats and downsampled masks
+ # in case the visual features are on CPU, cast them to CUDA
+ pix_feat = pix_feat.to(masks.device)
+
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)
+
+ pos = self.position_encoding(x).to(x.dtype)
+
+ return {"vision_features": x, "vision_pos_enc": [pos]}
diff --git a/phantom/submodules/sam2/sam2/modeling/position_encoding.py b/phantom/submodules/sam2/sam2/modeling/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..2241d4cf1a4495b4c67dc35cbed1c606357b9b7a
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/position_encoding.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, Optional, Tuple
+
+import numpy as np
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention Is All You Need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self,
+ num_pos_feats,
+ temperature: int = 10000,
+ normalize: bool = True,
+ scale: Optional[float] = None,
+ # Following settings only relevant
+ # for warmping up cache for compilation
+ warmup_cache: bool = True,
+ image_size: int = 1024,
+ strides: Tuple[int] = (4, 8, 16, 32),
+ ):
+ super().__init__()
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
+ self.num_pos_feats = num_pos_feats // 2
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ self.cache = {}
+ if warmup_cache and torch.cuda.is_available():
+ # Warmup cache for cuda, to help with compilation
+ device = torch.device("cuda")
+ for stride in strides:
+ cache_key = (image_size // stride, image_size // stride)
+ self._pe(1, device, *cache_key)
+
+ def _encode_xy(self, x, y):
+ # The positions are expected to be normalized
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+ ).flatten(1)
+ pos_y = torch.stack(
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+ ).flatten(1)
+ return pos_x, pos_y
+
+ @torch.no_grad()
+ def encode_boxes(self, x, y, w, h):
+ pos_x, pos_y = self._encode_xy(x, y)
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+ return pos
+
+ encode = encode_boxes # Backwards compatibility
+
+ @torch.no_grad()
+ def encode_points(self, x, y, labels):
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+ assert bx == by and nx == ny and bx == bl and nx == nl
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+ return pos
+
+ @torch.no_grad()
+ def _pe(self, B, device, *cache_key):
+ H, W = cache_key
+ if cache_key in self.cache:
+ return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
+
+ y_embed = (
+ torch.arange(1, H + 1, dtype=torch.float32, device=device)
+ .view(1, -1, 1)
+ .repeat(B, 1, W)
+ )
+ x_embed = (
+ torch.arange(1, W + 1, dtype=torch.float32, device=device)
+ .view(1, 1, -1)
+ .repeat(B, H, 1)
+ )
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ self.cache[cache_key] = pos[0]
+ return pos
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor):
+ B = x.shape[0]
+ cache_key = (x.shape[-2], x.shape[-1])
+ return self._pe(B, x.device, *cache_key)
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
+
+
+# Rotary Positional Encoding, adapted from:
+# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# 2. https://github.com/naver-ai/rope-vit
+# 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+def init_t_xy(end_x: int, end_y: int):
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
+ t_x = (t % end_x).float()
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
+ return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ t_x, t_y = init_t_xy(end_x, end_y)
+ freqs_x = torch.outer(t_x, freqs_x)
+ freqs_y = torch.outer(t_y, freqs_y)
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ repeat_freqs_k: bool = False,
+):
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = (
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ if xk.shape[-2] != 0
+ else None
+ )
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk_ is None:
+ # no keys to rotate, due to dropout
+ return xq_out.type_as(xq).to(xq.device), xk
+ # repeat freqs along seq_len dim to match k seq_len
+ if repeat_freqs_k:
+ r = xk_.shape[-2] // xq_.shape[-2]
+ if freqs_cis.is_cuda:
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+ else:
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
diff --git a/phantom/submodules/sam2/sam2/modeling/sam/__init__.py b/phantom/submodules/sam2/sam2/modeling/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/sam/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/sam2/modeling/sam/mask_decoder.py b/phantom/submodules/sam2/sam2/modeling/sam/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bebc0366b2703ffcb80a44bfd19cce8339b4fed
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/sam/mask_decoder.py
@@ -0,0 +1,295 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.sam2_utils import LayerNorm2d, MLP
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ use_high_res_features: bool = False,
+ iou_prediction_use_sigmoid=False,
+ dynamic_multimask_via_stability=False,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.pred_obj_scores = pred_obj_scores
+ if self.pred_obj_scores:
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+ ),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+ ),
+ activation(),
+ )
+ self.use_high_res_features = use_high_res_features
+ if use_high_res_features:
+ self.conv_s0 = nn.Conv2d(
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
+ )
+ self.conv_s1 = nn.Conv2d(
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
+ )
+
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim,
+ iou_head_hidden_dim,
+ self.num_mask_tokens,
+ iou_head_depth,
+ sigmoid_output=iou_prediction_use_sigmoid,
+ )
+ if self.pred_obj_scores:
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+ if pred_obj_scores_mlp:
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+ # When outputting a single mask, optionally we can dynamically fall back to the best
+ # multimask output token if the single mask output token gives low stability scores.
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ torch.Tensor: batched SAM token for mask output
+ """
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ repeat_image=repeat_image,
+ high_res_features=high_res_features,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ masks = masks[:, 1:, :, :]
+ iou_pred = iou_pred[:, 1:]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ masks = masks[:, 0:1, :, :]
+ iou_pred = iou_pred[:, 0:1]
+
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
+ else:
+ # Take the mask output token. Here we *always* use the token for single mask output.
+ # At test time, even if we track after 1-click (and using multimask_output=True),
+ # we still take the single mask token here. The rationale is that we always track
+ # after multiple clicks during training, so the past tokens seen during training
+ # are always the single mask token (and we'll let it be the object-memory token).
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
+
+ # Prepare output
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ s = 0
+ if self.pred_obj_scores:
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ s = 1
+ else:
+ output_tokens = torch.cat(
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
+ )
+ output_tokens = output_tokens.unsqueeze(0).expand(
+ sparse_prompt_embeddings.size(0), -1, -1
+ )
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ if repeat_image:
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ else:
+ assert image_embeddings.shape[0] == tokens.shape[0]
+ src = image_embeddings
+ src = src + dense_prompt_embeddings
+ assert (
+ image_pe.size(0) == 1
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, s, :]
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ if not self.use_high_res_features:
+ upscaled_embedding = self.output_upscaling(src)
+ else:
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
+ feat_s0, feat_s1 = high_res_features
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+ )
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ if self.pred_obj_scores:
+ assert s == 1
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+ else:
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+ return masks, iou_pred, mask_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+ batch_inds = torch.arange(
+ multimask_iou_scores.size(0), device=all_iou_scores.device
+ )
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
diff --git a/phantom/submodules/sam2/sam2/modeling/sam/prompt_encoder.py b/phantom/submodules/sam2/sam2/modeling/sam/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57876264b51f8c5236867359350e32d590efcb5
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/sam/prompt_encoder.py
@@ -0,0 +1,202 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.position_encoding import PositionEmbeddingRandom
+
+from sam2.modeling.sam2_utils import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (int): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
+ ]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (
+ 4 * image_embedding_size[0],
+ 4 * image_embedding_size[1],
+ )
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(
+ points, self.input_image_size
+ )
+
+ point_embedding = torch.where(
+ (labels == -1).unsqueeze(-1),
+ torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 0).unsqueeze(-1),
+ point_embedding + self.point_embeddings[0].weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 1).unsqueeze(-1),
+ point_embedding + self.point_embeddings[1].weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 2).unsqueeze(-1),
+ point_embedding + self.point_embeddings[2].weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 3).unsqueeze(-1),
+ point_embedding + self.point_embeddings[3].weight,
+ point_embedding,
+ )
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(
+ coords, self.input_image_size
+ )
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty(
+ (bs, 0, self.embed_dim), device=self._get_device()
+ )
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
diff --git a/phantom/submodules/sam2/sam2/modeling/sam/transformer.py b/phantom/submodules/sam2/sam2/modeling/sam/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9fe9a3fbc5cce4f1abe8ee0ae3a8602bbe2ff1b
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/sam/transformer.py
@@ -0,0 +1,311 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLP(
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+ )
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ dropout: float = 0.0,
+ kv_in_dim: int = None,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert (
+ self.internal_dim % num_heads == 0
+ ), "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ self.dropout_p = dropout
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ dropout_p = self.dropout_p if self.training else 0.0
+ # Attention
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
+
+
+class RoPEAttention(Attention):
+ """Attention with rotary position encoding."""
+
+ def __init__(
+ self,
+ *args,
+ rope_theta=10000.0,
+ # whether to repeat q rope to match k length
+ # this is needed for cross-attention to memories
+ rope_k_repeat=False,
+ feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.compute_cis = partial(
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+ )
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+ self.freqs_cis = (
+ freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+ )
+ self.rope_k_repeat = rope_k_repeat
+
+ def forward(
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+ ) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Apply rotary position encoding
+ w = h = math.sqrt(q.shape[-2])
+ self.freqs_cis = self.freqs_cis.to(q.device)
+ if self.freqs_cis.shape[0] != q.shape[-2]:
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+ if q.shape[-2] != k.shape[-2]:
+ assert self.rope_k_repeat
+
+ num_k_rope = k.size(-2) - num_k_exclude_rope
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
+ q,
+ k[:, :, :num_k_rope],
+ freqs_cis=self.freqs_cis,
+ repeat_freqs_k=self.rope_k_repeat,
+ )
+
+ dropout_p = self.dropout_p if self.training else 0.0
+ # Attention
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
diff --git a/phantom/submodules/sam2/sam2/modeling/sam2_base.py b/phantom/submodules/sam2/sam2/modeling/sam2_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f4e515b0d161942bf2bb64560056b3efbe6dac
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/sam2_base.py
@@ -0,0 +1,909 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+
+from torch.nn.init import trunc_normal_
+
+from sam2.modeling.sam.mask_decoder import MaskDecoder
+from sam2.modeling.sam.prompt_encoder import PromptEncoder
+from sam2.modeling.sam.transformer import TwoWayTransformer
+from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+class SAM2Base(torch.nn.Module):
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention,
+ memory_encoder,
+ num_maskmem=7, # default 1 input frame + 6 previous frames
+ image_size=512,
+ backbone_stride=16, # stride of the image backbone output
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
+ binarize_mask_from_pts_for_mem_enc=False,
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
+ max_cond_frames_in_attn=-1,
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
+ # (instead of using the transformer encoder)
+ directly_add_no_mem_embed=False,
+ # whether to use high-resolution feature maps in the SAM mask decoder
+ use_high_res_features_in_sam=False,
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
+ multimask_output_in_sam=False,
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
+ multimask_min_pt_num=1,
+ multimask_max_pt_num=1,
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
+ multimask_output_for_tracking=False,
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
+ use_multimask_token_for_obj_ptr: bool = False,
+ # whether to use sigmoid to restrict ious prediction to [0-1]
+ iou_prediction_use_sigmoid=False,
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
+ memory_temporal_stride_for_eval=1,
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
+ non_overlap_masks_for_mem_enc=False,
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder=False,
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
+ max_obj_ptrs_in_encoder=16,
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
+ add_tpos_enc_to_obj_ptrs=True,
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+ proj_tpos_enc_in_obj_ptrs=False,
+ # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
+ # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+ use_signed_tpos_enc_to_obj_ptrs=False,
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
+ only_obj_ptrs_in_the_past_for_eval=False,
+ # Whether to predict if there is an object in the frame
+ pred_obj_scores: bool = False,
+ # Whether to use an MLP to predict object scores
+ pred_obj_scores_mlp: bool = False,
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
+ # Whether to have a fixed no obj pointer when there is no object present
+ # or to use it as an additive embedding with obj_ptr produced by decoder
+ fixed_no_obj_ptr: bool = False,
+ # Soft no object, i.e. mix in no_obj_ptr softly,
+ # hope to make recovery easier if there is a mistake and mitigate accumulation of errors
+ soft_no_obj_ptr: bool = False,
+ use_mlp_for_obj_ptr_proj: bool = False,
+ # add no obj embedding to spatial frames
+ no_obj_embed_spatial: bool = False,
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
+ sam_mask_decoder_extra_args=None,
+ compile_image_encoder: bool = False,
+ ):
+ super().__init__()
+
+ # Part 1: the image backbone
+ self.image_encoder = image_encoder
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+ if use_obj_ptrs_in_encoder:
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+ if proj_tpos_enc_in_obj_ptrs:
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+ self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+ # Part 2: memory attention to condition current frame's visual features
+ # with memories (and obj ptrs) from past frames
+ self.memory_attention = memory_attention
+ self.hidden_dim = image_encoder.neck.d_model
+
+ # Part 3: memory encoder for the previous frame's outputs
+ self.memory_encoder = memory_encoder
+ self.mem_dim = self.hidden_dim
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
+ self.memory_encoder.out_proj, "weight"
+ ):
+ # if there is compression of memories along channel dim
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+ self.num_maskmem = num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.maskmem_tpos_enc = torch.nn.Parameter(
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
+ )
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+ # a single token to indicate no memory embedding from previous frames
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ trunc_normal_(self.no_mem_embed, std=0.02)
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
+ # Apply sigmoid to the output raw mask logits (to turn them from
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+ # On frames with mask input, whether to directly output the input mask without
+ # using a SAM prompt encoder + mask decoder
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+ # and SAM-style mask decoder for the final mask output
+ self.image_size = image_size
+ self.backbone_stride = backbone_stride
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+ self.pred_obj_scores = pred_obj_scores
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
+ self.soft_no_obj_ptr = soft_no_obj_ptr
+ if self.fixed_no_obj_ptr:
+ assert self.pred_obj_scores
+ assert self.use_obj_ptrs_in_encoder
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ trunc_normal_(self.no_obj_ptr, std=0.02)
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+ self.no_obj_embed_spatial = None
+ if no_obj_embed_spatial:
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
+
+ self._build_sam_heads()
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+ # Model compilation
+ if compile_image_encoder:
+ # Compile the forward function (not the full module) to allow loading checkpoints.
+ print(
+ "Image encoder compilation is enabled. First forward pass will be slow."
+ )
+ self.image_encoder.forward = torch.compile(
+ self.image_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
+ "See notebooks/video_predictor_example.ipynb for an inference example."
+ )
+
+ def _build_sam_heads(self):
+ """Build SAM-style prompt encoder and mask decoder."""
+ self.sam_prompt_embed_dim = self.hidden_dim
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+ # build PromptEncoder and MaskDecoder from SAM
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+ self.sam_prompt_encoder = PromptEncoder(
+ embed_dim=self.sam_prompt_embed_dim,
+ image_embedding_size=(
+ self.sam_image_embedding_size,
+ self.sam_image_embedding_size,
+ ),
+ input_image_size=(self.image_size, self.image_size),
+ mask_in_chans=16,
+ )
+ self.sam_mask_decoder = MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=self.sam_prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=self.sam_prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ use_high_res_features=self.use_high_res_features_in_sam,
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+ pred_obj_scores=self.pred_obj_scores,
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+ **(self.sam_mask_decoder_extra_args or {}),
+ )
+ if self.use_obj_ptrs_in_encoder:
+ # a linear projection on SAM output tokens to turn them into object pointers
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+ if self.use_mlp_for_obj_ptr_proj:
+ self.obj_ptr_proj = MLP(
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
+ )
+ else:
+ self.obj_ptr_proj = torch.nn.Identity()
+ if self.proj_tpos_enc_in_obj_ptrs:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Forward SAM prompt encoders and mask heads.
+
+ Inputs:
+ - backbone_features: image features of [B, C, H, W] shape
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+ absolute pixel-unit coordinate in (x, y) format of the P input points
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+ positive clicks, 0 means negative clicks, and -1 means padding
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+ same spatial size as the image.
+ - high_res_features: either 1) None or 2) or a list of length 2 containing
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+ which will be used as high-resolution feature maps for SAM decoder.
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
+ its corresponding IoU estimate.
+
+ Outputs:
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
+ the resolution (1/4 stride) of the input backbone_features.
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
+ upsampled from the low-resolution masks, with shape size as the image
+ (stride is 1 pixel).
+ - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
+ if `multimask_output=False`), the estimated IoU of each output mask.
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+ based on the output token from the SAM mask decoder.
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provide, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ """
+ Directly turn binary `mask_inputs` into a output mask logits without using SAM.
+ (same input and output shapes as in _forward_sam_heads above).
+ """
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.float()
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks,
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ # a dummy IoU prediction of all 1's under mask input
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+ if not self.use_obj_ptrs_in_encoder:
+ # all zeros as a dummy object pointer (of shape [B, C])
+ obj_ptr = torch.zeros(
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
+ )
+ else:
+ # produce an object pointer using the SAM decoder from the mask input
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+ backbone_features=backbone_features,
+ mask_inputs=self.mask_downsample(mask_inputs_float),
+ high_res_features=high_res_features,
+ )
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+ # on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ if self.pred_obj_scores:
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_masks,
+ high_res_masks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """Get the image feature on the input batch."""
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+ backbone_out["backbone_fpn"][0]
+ )
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+ backbone_out["backbone_fpn"][1]
+ )
+ return backbone_out
+
+ def _prepare_backbone_features(self, backbone_out):
+ """Prepare and flatten visual features."""
+ backbone_out = backbone_out.copy()
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+ # flatten NxCxHxW to HWxNxC
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+ def _prepare_memory_conditioned_features(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ ):
+ """Fuse the current frame's visual feature map with previous memory."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ device = current_vision_feats[-1].device
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+ # In this case, we skip the fusion with any memory.
+ if self.num_maskmem == 0: # Disable memory and skip fusion
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat
+
+ num_obj_ptr_tokens = 0
+ tpos_sign_mul = -1 if track_in_reverse else 1
+ # Step 1: condition the visual features of the current frame on previous memories
+ if not is_init_cond_frame:
+ # Retrieve the memories encoded with the maskmem backbone
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+ # Add conditioning frames's output first (all cond frames have t_pos=0 for
+ # when getting temporal positional embedding below)
+ assert len(output_dict["cond_frame_outputs"]) > 0
+ # Select a maximum number of temporally closest cond frames for cross attention
+ cond_outputs = output_dict["cond_frame_outputs"]
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
+ )
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
+ # We also allow taking the memory frame non-consecutively (with stride>1), in which case
+ # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
+ stride = 1 if self.training else self.memory_temporal_stride_for_eval
+ for t_pos in range(1, self.num_maskmem):
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
+ if t_rel == 1:
+ # for t_rel == 1, we take the last frame (regardless of r)
+ if not track_in_reverse:
+ # the frame immediately before this frame (i.e. frame_idx - 1)
+ prev_frame_idx = frame_idx - t_rel
+ else:
+ # the frame immediately after this frame (i.e. frame_idx + 1)
+ prev_frame_idx = frame_idx + t_rel
+ else:
+ # for t_rel >= 2, we take the memory frame from every r-th frames
+ if not track_in_reverse:
+ # first find the nearest frame among every r-th frames before this frame
+ # for r=1, this would be (frame_idx - 2)
+ prev_frame_idx = ((frame_idx - 2) // stride) * stride
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
+ else:
+ # first find the nearest frame among every r-th frames after this frame
+ # for r=1, this would be (frame_idx + 2)
+ prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+ if out is None:
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+ # frames, we still attend to it as if it's a non-conditioning frame.
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
+ t_pos_and_prevs.append((t_pos, out))
+
+ for t_pos, prev in t_pos_and_prevs:
+ if prev is None:
+ continue # skip padding frames
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
+ feats = prev["maskmem_features"].to(device, non_blocking=True)
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+ # Temporal positional encoding
+ maskmem_enc = (
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+ )
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ # Construct the list of past object pointers
+ if self.use_obj_ptrs_in_encoder:
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+ # First add those object pointers from selected conditioning frames
+ # (optionally, only include object pointers in the past during evaluation)
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+ ptr_cond_outputs = {
+ t: out
+ for t, out in selected_cond_outputs.items()
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+ }
+ else:
+ ptr_cond_outputs = selected_cond_outputs
+ pos_and_ptrs = [
+ # Temporal pos encoding contains how far away each pointer is from current frame
+ (
+ (
+ (frame_idx - t) * tpos_sign_mul
+ if self.use_signed_tpos_enc_to_obj_ptrs
+ else abs(frame_idx - t)
+ ),
+ out["obj_ptr"],
+ )
+ for t, out in ptr_cond_outputs.items()
+ ]
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+ if t < 0 or (num_frames is not None and t >= num_frames):
+ break
+ out = output_dict["non_cond_frame_outputs"].get(
+ t, unselected_cond_outputs.get(t, None)
+ )
+ if out is not None:
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+ # If we have at least one object pointer, add them to the across attention
+ if len(pos_and_ptrs) > 0:
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
+ # a temporal positional embedding based on how far each object pointer is from
+ # the current frame (sine embedding normalized by the max pointer num).
+ if self.add_tpos_enc_to_obj_ptrs:
+ t_diff_max = max_obj_ptrs_in_encoder - 1
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+ obj_pos = torch.tensor(pos_list).to(
+ device=device, non_blocking=True
+ )
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+ else:
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+ if self.mem_dim < C:
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+ obj_ptrs = obj_ptrs.reshape(
+ -1, B, C // self.mem_dim, self.mem_dim
+ )
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+ to_cat_memory.append(obj_ptrs)
+ to_cat_memory_pos_embed.append(obj_pos)
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
+ else:
+ num_obj_ptr_tokens = 0
+ else:
+ # for initial conditioning frames, encode them without using any previous memory
+ if self.directly_add_no_mem_embed:
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+ # Step 2: Concatenate the memories and forward through the transformer encoder
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+ pix_feat_with_mem = self.memory_attention(
+ curr=current_vision_feats,
+ curr_pos=current_vision_pos_embeds,
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
+ )
+ # reshape the output (HW)BC => BCHW
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """Encode the current image and its prediction into a memory feature."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
+ pred_masks_high_res
+ )
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
+ )
+ maskmem_features = maskmem_out["vision_features"]
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.no_obj_embed_spatial is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (
+ 1 - is_obj_appearing[..., None, None]
+ ) * self.no_obj_embed_spatial[..., None, None].expand(
+ *maskmem_features.shape
+ )
+
+ return maskmem_features, maskmem_pos_enc
+
+ def _track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse,
+ prev_sam_mask_logits,
+ ):
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+ sam_outputs = self._use_mask_as_output(
+ pix_feat, high_res_features, mask_inputs
+ )
+ else:
+ # fused the visual feature with previous memory features in the memory bank
+ pix_feat = self._prepare_memory_conditioned_features(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats[-1:],
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+ feat_sizes=feat_sizes[-1:],
+ output_dict=output_dict,
+ num_frames=num_frames,
+ track_in_reverse=track_in_reverse,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ assert point_inputs is not None and mask_inputs is None
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+
+ return current_out, sam_outputs, high_res_features, pix_feat
+
+ def _encode_memory_in_output(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ point_inputs,
+ run_mem_encoder,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ ):
+ if run_mem_encoder and self.num_maskmem > 0:
+ high_res_masks_for_mem_enc = high_res_masks
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks_for_mem_enc,
+ object_score_logits=object_score_logits,
+ is_mask_from_pts=(point_inputs is not None),
+ )
+ current_out["maskmem_features"] = maskmem_features
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ current_out["maskmem_features"] = None
+ current_out["maskmem_pos_enc"] = None
+
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
+ # in demo we might call `track_step` multiple times for each user click,
+ # and only encode the memory when the user finalizes their clicks. And in ablation
+ # settings like SAM training on static images, we don't need the memory encoder.
+ run_mem_encoder=True,
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+ prev_sam_mask_logits=None,
+ ):
+ current_out, sam_outputs, _, _ = self._track_step(
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse,
+ prev_sam_mask_logits,
+ )
+
+ (
+ _,
+ _,
+ _,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ ) = sam_outputs
+
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+ if not self.training:
+ # Only add this in inference (to avoid unused param in activation checkpointing;
+ # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+ current_out["object_score_logits"] = object_score_logits
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (that can be used in future frames)
+ self._encode_memory_in_output(
+ current_vision_feats,
+ feat_sizes,
+ point_inputs,
+ run_mem_encoder,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ )
+
+ return current_out
+
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
+ """Whether to use multimask output in the SAM head."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+ multimask_output = (
+ self.multimask_output_in_sam
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _apply_non_overlapping_constraints(self, pred_masks):
+ """
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
+ keep only the highest scoring object at each spatial location in pred_masks.
+ """
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
diff --git a/phantom/submodules/sam2/sam2/modeling/sam2_utils.py b/phantom/submodules/sam2/sam2/modeling/sam2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16caae3a9a49e451b2d03d1ee60c47f8e9ed23c
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/modeling/sam2_utils.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.utils.misc import mask_to_box
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+ """
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
+ - a) the closest conditioning frame before `frame_idx` (if any);
+ - b) the closest conditioning frame after `frame_idx` (if any);
+ - c) any other temporally closest conditioning frames until reaching a total
+ of `max_cond_frame_num` conditioning frames.
+
+ Outputs:
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
+ """
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+ selected_outputs = cond_frame_outputs
+ unselected_outputs = {}
+ else:
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+ selected_outputs = {}
+
+ # the closest conditioning frame before `frame_idx` (if any)
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+ if idx_before is not None:
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+ # the closest conditioning frame after `frame_idx` (if any)
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+ if idx_after is not None:
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+ # add other temporally closest conditioning frames until reaching a total
+ # of `max_cond_frame_num` conditioning frames.
+ num_remain = max_cond_frame_num - len(selected_outputs)
+ inds_remain = sorted(
+ (t for t in cond_frame_outputs if t not in selected_outputs),
+ key=lambda x: abs(x - frame_idx),
+ )[:num_remain]
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+ unselected_outputs = {
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
+ }
+
+ return selected_outputs, unselected_outputs
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """
+ Get 1D sine positional embedding as in the original Transformer paper.
+ """
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
+
+def get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+
+def get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DropPath(nn.Module):
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: nn.Module = nn.ReLU,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+ self.act = activation()
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+def sample_box_points(
+ masks: torch.Tensor,
+ noise: float = 0.1, # SAM default
+ noise_bound: int = 20, # SAM default
+ top_left_label: int = 2,
+ bottom_right_label: int = 3,
+) -> Tuple[np.array, np.array]:
+ """
+ Sample a noised version of the top left and bottom right corners of a given `bbox`
+
+ Inputs:
+ - masks: [B, 1, H,W] boxes, dtype=torch.Tensor
+ - noise: noise as a fraction of box width and height, dtype=float
+ - noise_bound: maximum amount of noise (in pure pixesl), dtype=int
+
+ Returns:
+ - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
+ - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32
+ """
+ device = masks.device
+ box_coords = mask_to_box(masks)
+ B, _, H, W = masks.shape
+ box_labels = torch.tensor(
+ [top_left_label, bottom_right_label], dtype=torch.int, device=device
+ ).repeat(B)
+ if noise > 0.0:
+ if not isinstance(noise_bound, torch.Tensor):
+ noise_bound = torch.tensor(noise_bound, device=device)
+ bbox_w = box_coords[..., 2] - box_coords[..., 0]
+ bbox_h = box_coords[..., 3] - box_coords[..., 1]
+ max_dx = torch.min(bbox_w * noise, noise_bound)
+ max_dy = torch.min(bbox_h * noise, noise_bound)
+ box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
+ box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
+
+ box_coords = box_coords + box_noise
+ img_bounds = (
+ torch.tensor([W, H, W, H], device=device) - 1
+ ) # uncentered pixel coords
+ box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
+
+ box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
+ box_labels = box_labels.reshape(-1, 2)
+ return box_coords, box_labels
+
+
+def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
+ """
+ Sample `num_pt` random points (along with their labels) independently from the error regions.
+
+ Inputs:
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
+ - num_pt: int, number of points to sample independently for each of the B error maps
+
+ Outputs:
+ - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
+ - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
+ negative clicks
+ """
+ if pred_masks is None: # if pred_masks is not provided, treat it as empty
+ pred_masks = torch.zeros_like(gt_masks)
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
+ assert num_pt >= 0
+
+ B, _, H_im, W_im = gt_masks.shape
+ device = gt_masks.device
+
+ # false positive region, a new point sampled in this region should have
+ # negative label to correct the FP error
+ fp_masks = ~gt_masks & pred_masks
+ # false negative region, a new point sampled in this region should have
+ # positive label to correct the FN error
+ fn_masks = gt_masks & ~pred_masks
+ # whether the prediction completely match the ground-truth on each mask
+ all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
+ all_correct = all_correct[..., None, None]
+
+ # channel 0 is FP map, while channel 1 is FN map
+ pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
+ # sample a negative new click from FP region or a positive new click
+ # from FN region, depend on where the maximum falls,
+ # and in case the predictions are all correct (no FP or FN), we just
+ # sample a negative click from the background region
+ pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
+ pts_noise[..., 1] *= fn_masks
+ pts_idx = pts_noise.flatten(2).argmax(dim=2)
+ labels = (pts_idx % 2).to(torch.int32)
+ pts_idx = pts_idx // 2
+ pts_x = pts_idx % W_im
+ pts_y = pts_idx // W_im
+ points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
+ return points, labels
+
+
+def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
+ """
+ Sample 1 random point (along with its label) from the center of each error region,
+ that is, the point with the largest distance to the boundary of each error region.
+ This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
+
+ Inputs:
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
+ - padding: if True, pad with boundary of 1 px for distance transform
+
+ Outputs:
+ - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
+ - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
+ """
+ import cv2
+
+ if pred_masks is None:
+ pred_masks = torch.zeros_like(gt_masks)
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
+
+ B, _, _, W_im = gt_masks.shape
+ device = gt_masks.device
+
+ # false positive region, a new point sampled in this region should have
+ # negative label to correct the FP error
+ fp_masks = ~gt_masks & pred_masks
+ # false negative region, a new point sampled in this region should have
+ # positive label to correct the FN error
+ fn_masks = gt_masks & ~pred_masks
+
+ fp_masks = fp_masks.cpu().numpy()
+ fn_masks = fn_masks.cpu().numpy()
+ points = torch.zeros(B, 1, 2, dtype=torch.float)
+ labels = torch.ones(B, 1, dtype=torch.int32)
+ for b in range(B):
+ fn_mask = fn_masks[b, 0]
+ fp_mask = fp_masks[b, 0]
+ if padding:
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
+ # compute the distance of each point in FN/FP region to its boundary
+ fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
+ fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
+ if padding:
+ fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
+ fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
+
+ # take the point in FN/FP region with the largest distance to its boundary
+ fn_mask_dt_flat = fn_mask_dt.reshape(-1)
+ fp_mask_dt_flat = fp_mask_dt.reshape(-1)
+ fn_argmax = np.argmax(fn_mask_dt_flat)
+ fp_argmax = np.argmax(fp_mask_dt_flat)
+ is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
+ pt_idx = fn_argmax if is_positive else fp_argmax
+ points[b, 0, 0] = pt_idx % W_im # x
+ points[b, 0, 1] = pt_idx // W_im # y
+ labels[b, 0] = int(is_positive)
+
+ points = points.to(device)
+ labels = labels.to(device)
+ return points, labels
+
+
+def get_next_point(gt_masks, pred_masks, method):
+ if method == "uniform":
+ return sample_random_points_from_errors(gt_masks, pred_masks)
+ elif method == "center":
+ return sample_one_point_from_error_center(gt_masks, pred_masks)
+ else:
+ raise ValueError(f"unknown sampling method {method}")
diff --git a/phantom/submodules/sam2/sam2/sam2_hiera_b+.yaml b/phantom/submodules/sam2/sam2/sam2_hiera_b+.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..998d9c98c9ff4e8ddd55deff72aa0d9067977418
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_hiera_b+.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_b+.yaml
\ No newline at end of file
diff --git a/phantom/submodules/sam2/sam2/sam2_hiera_l.yaml b/phantom/submodules/sam2/sam2/sam2_hiera_l.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..c0e7e58e1951d5c55a3a3ebe6b803dd814cf9d86
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_hiera_l.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_l.yaml
\ No newline at end of file
diff --git a/phantom/submodules/sam2/sam2/sam2_hiera_s.yaml b/phantom/submodules/sam2/sam2/sam2_hiera_s.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..41896a26beb2aa831d18b0bf3c349ed43deeef68
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_hiera_s.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_s.yaml
\ No newline at end of file
diff --git a/phantom/submodules/sam2/sam2/sam2_hiera_t.yaml b/phantom/submodules/sam2/sam2/sam2_hiera_t.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..71ff3abbb1e11f8b82100a0a1d63cb267eefe52a
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_hiera_t.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_t.yaml
\ No newline at end of file
diff --git a/phantom/submodules/sam2/sam2/sam2_image_predictor.py b/phantom/submodules/sam2/sam2/sam2_image_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ce53af5924504c07216df52b2d2eefaeec7ae9
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_image_predictor.py
@@ -0,0 +1,466 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL.Image import Image
+
+from sam2.modeling.sam2_base import SAM2Base
+
+from sam2.utils.transforms import SAM2Transforms
+
+
+class SAM2ImagePredictor:
+ def __init__(
+ self,
+ sam_model: SAM2Base,
+ mask_threshold=0.0,
+ max_hole_area=0.0,
+ max_sprinkle_area=0.0,
+ **kwargs,
+ ) -> None:
+ """
+ Uses SAM-2 to calculate the image embedding for an image, and then
+ allow repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam-2): The model to use for mask prediction.
+ mask_threshold (float): The threshold to use when converting mask logits
+ to binary masks. Masks are thresholded at 0 by default.
+ max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
+ the maximum area of max_hole_area in low_res_masks.
+ max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
+ the maximum area of max_sprinkle_area in low_res_masks.
+ """
+ super().__init__()
+ self.model = sam_model
+ self._transforms = SAM2Transforms(
+ resolution=self.model.image_size,
+ mask_threshold=mask_threshold,
+ max_hole_area=max_hole_area,
+ max_sprinkle_area=max_sprinkle_area,
+ )
+
+ # Predictor state
+ self._is_image_set = False
+ self._features = None
+ self._orig_hw = None
+ # Whether the predictor is set for single image or a batch of images
+ self._is_batch = False
+
+ # Predictor config
+ self.mask_threshold = mask_threshold
+
+ # Spatial dim for backbone feature maps
+ self._bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ ]
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+ **kwargs: Additional arguments to pass to the model constructor.
+
+ Returns:
+ (SAM2ImagePredictor): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_hf
+
+ sam_model = build_sam2_hf(model_id, **kwargs)
+ return cls(sam_model, **kwargs)
+
+ @torch.no_grad()
+ def set_image(
+ self,
+ image: Union[np.ndarray, Image],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
+ with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ self.reset_predictor()
+ # Transform the image to the form expected by the model
+ if isinstance(image, np.ndarray):
+ logging.info("For numpy array image, we assume (HxWxC) format")
+ self._orig_hw = [image.shape[:2]]
+ elif isinstance(image, Image):
+ w, h = image.size
+ self._orig_hw = [(h, w)]
+ else:
+ raise NotImplementedError("Image format not supported")
+
+ input_image = self._transforms(image)
+ input_image = input_image[None, ...].to(self.device)
+
+ assert (
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+ logging.info("Computing image embeddings for the provided image...")
+ backbone_out = self.model.forward_image(input_image)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+ feats = [
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+ self._is_image_set = True
+ logging.info("Image embeddings computed.")
+
+ @torch.no_grad()
+ def set_image_batch(
+ self,
+ image_list: List[Union[np.ndarray]],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image batch, allowing
+ masks to be predicted with the 'predict_batch' method.
+
+ Arguments:
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
+ with pixel values in [0, 255].
+ """
+ self.reset_predictor()
+ assert isinstance(image_list, list)
+ self._orig_hw = []
+ for image in image_list:
+ assert isinstance(
+ image, np.ndarray
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
+ self._orig_hw.append(image.shape[:2])
+ # Transform the image to the form expected by the model
+ img_batch = self._transforms.forward_batch(image_list)
+ img_batch = img_batch.to(self.device)
+ batch_size = img_batch.shape[0]
+ assert (
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
+ logging.info("Computing image embeddings for the provided images...")
+ backbone_out = self.model.forward_image(img_batch)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+ feats = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+ self._is_image_set = True
+ self._is_batch = True
+ logging.info("Image embeddings computed.")
+
+ def predict_batch(
+ self,
+ point_coords_batch: List[np.ndarray] = None,
+ point_labels_batch: List[np.ndarray] = None,
+ box_batch: List[np.ndarray] = None,
+ mask_input_batch: List[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ normalize_coords=True,
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
+ It returns a tuple of lists of masks, ious, and low_res_masks_logits.
+ """
+ assert self._is_batch, "This function should only be used when in batched mode"
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image_batch(...) before mask prediction."
+ )
+ num_images = len(self._features["image_embed"])
+ all_masks = []
+ all_ious = []
+ all_low_res_masks = []
+ for img_idx in range(num_images):
+ # Transform input prompts
+ point_coords = (
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
+ )
+ point_labels = (
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
+ )
+ box = box_batch[img_idx] if box_batch is not None else None
+ mask_input = (
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
+ )
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+ point_coords,
+ point_labels,
+ box,
+ mask_input,
+ normalize_coords,
+ img_idx=img_idx,
+ )
+ masks, iou_predictions, low_res_masks = self._predict(
+ unnorm_coords,
+ labels,
+ unnorm_box,
+ mask_input,
+ multimask_output,
+ return_logits=return_logits,
+ img_idx=img_idx,
+ )
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = (
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
+ )
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+ all_masks.append(masks_np)
+ all_ious.append(iou_predictions_np)
+ all_low_res_masks.append(low_res_masks_np)
+
+ return all_masks, all_ious, all_low_res_masks
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ normalize_coords=True,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+ normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ # Transform input prompts
+
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+ point_coords, point_labels, box, mask_input, normalize_coords
+ )
+
+ masks, iou_predictions, low_res_masks = self._predict(
+ unnorm_coords,
+ labels,
+ unnorm_box,
+ mask_input,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+ return masks_np, iou_predictions_np, low_res_masks_np
+
+ def _prep_prompts(
+ self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+ ):
+
+ unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = torch.as_tensor(
+ point_coords, dtype=torch.float, device=self.device
+ )
+ unnorm_coords = self._transforms.transform_coords(
+ point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+ )
+ labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ if len(unnorm_coords.shape) == 2:
+ unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
+ if box is not None:
+ box = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ unnorm_box = self._transforms.transform_boxes(
+ box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+ ) # Bx2x2
+ if mask_logits is not None:
+ mask_input = torch.as_tensor(
+ mask_logits, dtype=torch.float, device=self.device
+ )
+ if len(mask_input.shape) == 3:
+ mask_input = mask_input[None, :, :, :]
+ return mask_input, unnorm_coords, labels, unnorm_box
+
+ @torch.no_grad()
+ def _predict(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ img_idx: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using SAM2Transforms.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ if point_coords is not None:
+ concat_points = (point_coords, point_labels)
+ else:
+ concat_points = None
+
+ # Embed prompts
+ if boxes is not None:
+ box_coords = boxes.reshape(-1, 2, 2)
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+ box_labels = box_labels.repeat(boxes.size(0), 1)
+ # we merge "boxes" and "points" into a single "concat_points" input (where
+ # boxes are added at the beginning) to sam_prompt_encoder
+ if concat_points is not None:
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+ concat_points = (concat_coords, concat_labels)
+ else:
+ concat_points = (box_coords, box_labels)
+
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=concat_points,
+ boxes=None,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ batched_mode = (
+ concat_points is not None and concat_points[0].shape[0] > 1
+ ) # multi object prediction
+ high_res_features = [
+ feat_level[img_idx].unsqueeze(0)
+ for feat_level in self._features["high_res_feats"]
+ ]
+ low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=batched_mode,
+ high_res_features=high_res_features,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self._transforms.postprocess_masks(
+ low_res_masks, self._orig_hw[img_idx]
+ )
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
+ if not return_logits:
+ masks = masks > self.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert (
+ self._features is not None
+ ), "Features must exist if an image has been set."
+ return self._features["image_embed"]
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_predictor(self) -> None:
+ """
+ Resets the image embeddings and other state variables.
+ """
+ self._is_image_set = False
+ self._features = None
+ self._orig_hw = None
+ self._is_batch = False
diff --git a/phantom/submodules/sam2/sam2/sam2_video_predictor.py b/phantom/submodules/sam2/sam2/sam2_video_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a7e1a01c4d6e89db0453ce982ea8a31b16651c8
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_video_predictor.py
@@ -0,0 +1,1223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+
+class SAM2VideoPredictor(SAM2Base):
+ """The predictor class to handle user interactions and manage inference states."""
+
+ def __init__(
+ self,
+ fill_hole_area=0,
+ # whether to apply non-overlapping constraints on the output object masks
+ non_overlap_masks=False,
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+ # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
+ clear_non_cond_mem_around_input=False,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+ # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
+ add_all_frames_to_correct_as_cond=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.fill_hole_area = fill_hole_area
+ self.non_overlap_masks = non_overlap_masks
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+
+ @torch.inference_mode()
+ def init_state(
+ self,
+ video_path,
+ offload_video_to_cpu=False,
+ offload_state_to_cpu=False,
+ async_loading_frames=False,
+ ):
+ """Initialize an inference state."""
+ compute_device = self.device # device of the model
+ images, video_height, video_width = load_video_frames(
+ video_path=video_path,
+ image_size=self.image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ async_loading_frames=async_loading_frames,
+ compute_device=compute_device,
+ )
+ inference_state = {}
+ inference_state["images"] = images
+ inference_state["num_frames"] = len(images)
+ # whether to offload the video frames to CPU memory
+ # turning on this option saves the GPU memory with only a very small overhead
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+ # whether to offload the inference state to CPU memory
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+ # and from 24 to 21 when tracking two objects)
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+ # the original video height and width, used for resizing final output scores
+ inference_state["video_height"] = video_height
+ inference_state["video_width"] = video_width
+ inference_state["device"] = compute_device
+ if offload_state_to_cpu:
+ inference_state["storage_device"] = torch.device("cpu")
+ else:
+ inference_state["storage_device"] = compute_device
+ # inputs on each frame
+ inference_state["point_inputs_per_obj"] = {}
+ inference_state["mask_inputs_per_obj"] = {}
+ # visual features on a small number of recently visited frames for quick interactions
+ inference_state["cached_features"] = {}
+ # values that don't change across frames (so we only need to hold one copy of them)
+ inference_state["constants"] = {}
+ # mapping between client-side object id and model-side object index
+ inference_state["obj_id_to_idx"] = OrderedDict()
+ inference_state["obj_idx_to_id"] = OrderedDict()
+ inference_state["obj_ids"] = []
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+ inference_state["output_dict_per_obj"] = {}
+ # A temporary storage to hold new outputs when user interact with a frame
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+ inference_state["temp_output_dict_per_obj"] = {}
+ # Frames that already holds consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ inference_state["frames_tracked_per_obj"] = {}
+ # Warm up the visual backbone and cache the image feature on frame 0
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+ return inference_state
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+ **kwargs: Additional arguments to pass to the model constructor.
+
+ Returns:
+ (SAM2VideoPredictor): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_video_predictor_hf
+
+ sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+ return sam_model
+
+ def _obj_id_to_idx(self, inference_state, obj_id):
+ """Map client-side object id to model-side object index."""
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ # We always allow adding new objects (including after tracking starts).
+ allow_new_object = True
+ if allow_new_object:
+ # get the next object slot
+ obj_idx = len(inference_state["obj_id_to_idx"])
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
+ # set up input and output structures for this object
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ inference_state["output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ inference_state["frames_tracked_per_obj"][obj_idx] = {}
+ return obj_idx
+ else:
+ raise RuntimeError(
+ f"Cannot add new object id {obj_id} after tracking starts. "
+ f"All existing object ids: {inference_state['obj_ids']}. "
+ f"Please call 'reset_state' to restart from scratch."
+ )
+
+ def _obj_idx_to_id(self, inference_state, obj_idx):
+ """Map model-side object index to client-side object id."""
+ return inference_state["obj_idx_to_id"][obj_idx]
+
+ def _get_obj_num(self, inference_state):
+ """Get the total number of unique object ids received so far in this session."""
+ return len(inference_state["obj_idx_to_id"])
+
+ @torch.inference_mode()
+ def add_new_points_or_box(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ points=None,
+ labels=None,
+ clear_old_points=True,
+ normalize_coords=True,
+ box=None,
+ ):
+ """Add new points to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if (points is not None) != (labels is not None):
+ raise ValueError("points and labels must be provided together")
+ if points is None and box is None:
+ raise ValueError("at least one of points or box must be provided as input")
+
+ if points is None:
+ points = torch.zeros(0, 2, dtype=torch.float32)
+ elif not isinstance(points, torch.Tensor):
+ points = torch.tensor(points, dtype=torch.float32)
+ if labels is None:
+ labels = torch.zeros(0, dtype=torch.int32)
+ elif not isinstance(labels, torch.Tensor):
+ labels = torch.tensor(labels, dtype=torch.int32)
+ if points.dim() == 2:
+ points = points.unsqueeze(0) # add batch dimension
+ if labels.dim() == 1:
+ labels = labels.unsqueeze(0) # add batch dimension
+
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
+ # along with the user-provided points (consistent with how SAM 2 is trained).
+ if box is not None:
+ if not clear_old_points:
+ raise ValueError(
+ "cannot add box without clearing old points, since "
+ "box prompt must be provided before any point prompt "
+ "(please use clear_old_points=True instead)"
+ )
+ if not isinstance(box, torch.Tensor):
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
+ box_coords = box.reshape(1, 2, 2)
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+ box_labels = box_labels.reshape(1, 2)
+ points = torch.cat([box_coords, points], dim=1)
+ labels = torch.cat([box_labels, labels], dim=1)
+
+ if normalize_coords:
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
+ # scale the (normalized) coordinates by the model's internal image size
+ points = points * self.image_size
+ points = points.to(inference_state["device"])
+ labels = labels.to(inference_state["device"])
+
+ if not clear_old_points:
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
+ else:
+ point_inputs = None
+ point_inputs = concat_points(point_inputs, points, labels)
+
+ point_inputs_per_frame[frame_idx] = point_inputs
+ mask_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the inputs points are to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
+ is_init_cond_frame = frame_idx not in obj_frames_tracked
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = obj_frames_tracked[frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ # Get any previously predicted mask logits on this object and feed it along with
+ # the new clicks into the SAM mask decoder.
+ prev_sam_mask_logits = None
+ # lookup temporary output dict first, which contains the most recent output
+ # (if not found, then lookup conditioning and non-conditioning frame output)
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+ if prev_out is not None and prev_out["pred_masks"] is not None:
+ device = inference_state["device"]
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=None,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ def add_new_points(self, *args, **kwargs):
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
+ return self.add_new_points_or_box(*args, **kwargs)
+
+ @torch.inference_mode()
+ def add_new_mask(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ mask,
+ ):
+ """Add new mask to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.tensor(mask, dtype=torch.bool)
+ assert mask.dim() == 2
+ mask_H, mask_W = mask.shape
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+ # resize the mask if it doesn't match the model's image size
+ if mask_H != self.image_size or mask_W != self.image_size:
+ mask_inputs = torch.nn.functional.interpolate(
+ mask_inputs_orig,
+ size=(self.image_size, self.image_size),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ mask_inputs = (mask_inputs >= 0.5).float()
+ else:
+ mask_inputs = mask_inputs_orig
+
+ mask_inputs_per_frame[frame_idx] = mask_inputs
+ point_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the inputs points are to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
+ is_init_cond_frame = frame_idx not in obj_frames_tracked
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = obj_frames_tracked[frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
+ """
+ Resize the object scores to the original video resolution (video_res_masks)
+ and apply non-overlapping constraints for final output.
+ """
+ device = inference_state["device"]
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
+ if any_res_masks.shape[-2:] == (video_H, video_W):
+ video_res_masks = any_res_masks
+ else:
+ video_res_masks = torch.nn.functional.interpolate(
+ any_res_masks,
+ size=(video_H, video_W),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks:
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
+ return any_res_masks, video_res_masks
+
+ def _consolidate_temp_output_across_obj(
+ self,
+ inference_state,
+ frame_idx,
+ is_cond,
+ consolidate_at_video_res=False,
+ ):
+ """
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+ a frame into a single output for all objects, including
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
+ (if they don't exist in `output_dict_per_obj` for this frame);
+ 2) if specified, rerun memory encoder after apply non-overlapping constraints
+ on the object scores.
+ """
+ batch_size = self._get_obj_num(inference_state)
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Optionally, we allow consolidating the temporary outputs at the original
+ # video resolution (to provide a better editing experience for mask prompts).
+ if consolidate_at_video_res:
+ consolidated_H = inference_state["video_height"]
+ consolidated_W = inference_state["video_width"]
+ consolidated_mask_key = "pred_masks_video_res"
+ else:
+ consolidated_H = consolidated_W = self.image_size // 4
+ consolidated_mask_key = "pred_masks"
+
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+ # will be added when rerunning the memory encoder after applying non-overlapping
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
+ consolidated_out = {
+ consolidated_mask_key: torch.full(
+ size=(batch_size, 1, consolidated_H, consolidated_W),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["storage_device"],
+ ),
+ }
+ for obj_idx in range(batch_size):
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+ # we fall back and look up its previous output in "output_dict_per_obj".
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+ # "output_dict_per_obj" to find a previous output for this object.
+ if out is None:
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
+ if out is None:
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+ # placeholder above) and set its object pointer to be a dummy pointer.
+ if out is None:
+ continue
+ # Add the temporary object output mask to consolidated output mask
+ obj_mask = out["pred_masks"]
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
+ else:
+ # Resize first if temporary object mask has a different resolution
+ resized_obj_mask = torch.nn.functional.interpolate(
+ obj_mask,
+ size=consolidated_pred_masks.shape[-2:],
+ mode="bilinear",
+ align_corners=False,
+ )
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
+
+ return consolidated_out
+
+ @torch.inference_mode()
+ def propagate_in_video_preflight(self, inference_state):
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
+ # Check and make sure that every object has received input points or masks.
+ batch_size = self._get_obj_num(inference_state)
+ if batch_size == 0:
+ raise RuntimeError(
+ "No input points or masks are provided for any object; please add inputs first."
+ )
+
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+ # add them into "output_dict".
+ for obj_idx in range(batch_size):
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ for is_cond in [False, True]:
+ # Separately consolidate conditioning and non-conditioning temp outputs
+ storage_key = (
+ "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ )
+ # Find all the frames that contain temporary outputs for any objects
+ # (these should be the frames that have just received clicks for mask inputs
+ # via `add_new_points_or_box` or `add_new_mask`)
+ for frame_idx, out in obj_temp_output_dict[storage_key].items():
+ # Run memory encoder on the temporary outputs (if the memory feature is missing)
+ if out["maskmem_features"] is None:
+ high_res_masks = torch.nn.functional.interpolate(
+ out["pred_masks"].to(inference_state["device"]),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+ inference_state=inference_state,
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ high_res_masks=high_res_masks,
+ object_score_logits=out["object_score_logits"],
+ # these frames are what the user interacted with
+ is_mask_from_pts=True,
+ )
+ out["maskmem_features"] = maskmem_features
+ out["maskmem_pos_enc"] = maskmem_pos_enc
+
+ obj_output_dict[storage_key][frame_idx] = out
+ if self.clear_non_cond_mem_around_input:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_obj_non_cond_mem_around_input(
+ inference_state, frame_idx, obj_idx
+ )
+
+ # clear temporary outputs in `temp_output_dict_per_obj`
+ obj_temp_output_dict[storage_key].clear()
+
+ # check and make sure that every object has received input points or masks
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ if len(obj_output_dict["cond_frame_outputs"]) == 0:
+ obj_id = self._obj_idx_to_id(inference_state, obj_idx)
+ raise RuntimeError(
+ f"No input points or masks are provided for object id {obj_id}; please add inputs first."
+ )
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+ # output on the same frame in "non_cond_frame_outputs"
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+
+ @torch.inference_mode()
+ def propagate_in_video(
+ self,
+ inference_state,
+ start_frame_idx=None,
+ max_frame_num_to_track=None,
+ reverse=False,
+ ):
+ """Propagate the input points across frames to track in the entire video."""
+ self.propagate_in_video_preflight(inference_state)
+
+ obj_ids = inference_state["obj_ids"]
+ num_frames = inference_state["num_frames"]
+ batch_size = self._get_obj_num(inference_state)
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ start_frame_idx = min(
+ t
+ for obj_output_dict in inference_state["output_dict_per_obj"].values()
+ for t in obj_output_dict["cond_frame_outputs"]
+ )
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
+ )
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ pred_masks_per_obj = [None] * batch_size
+ for obj_idx in range(batch_size):
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ # We skip those frames already in consolidated outputs (these are frames
+ # that received input clicks or mask). Note that we cannot directly run
+ # batched forward on them via `_run_single_frame_inference` because the
+ # number of clicks on each object might be different.
+ if frame_idx in obj_output_dict["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = obj_output_dict[storage_key][frame_idx]
+ device = inference_state["device"]
+ pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
+ if self.clear_non_cond_mem_around_input:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_obj_non_cond_mem_around_input(
+ inference_state, frame_idx, obj_idx
+ )
+ else:
+ storage_key = "non_cond_frame_outputs"
+ current_out, pred_masks = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict,
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=False,
+ point_inputs=None,
+ mask_inputs=None,
+ reverse=reverse,
+ run_mem_encoder=True,
+ )
+ obj_output_dict[storage_key][frame_idx] = current_out
+
+ inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
+ "reverse": reverse
+ }
+ pred_masks_per_obj[obj_idx] = pred_masks
+
+ # Resize the output mask to the original video resolution (we directly use
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
+ if len(pred_masks_per_obj) > 1:
+ all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
+ else:
+ all_pred_masks = pred_masks_per_obj[0]
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, all_pred_masks
+ )
+ yield frame_idx, obj_ids, video_res_masks
+
+ @torch.inference_mode()
+ def clear_all_prompts_in_frame(
+ self, inference_state, frame_idx, obj_id, need_output=True
+ ):
+ """Remove all input points or mask in a specific frame for a given object."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+
+ # Clear the conditioning information on the given frame
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
+
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if out is not None:
+ # The frame is not a conditioning frame anymore since it's not receiving inputs,
+ # so we "downgrade" its output (if exists) to a non-conditioning frame output.
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
+ inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
+
+ if not need_output:
+ return
+ # Finally, output updated masks per object (after removing the inputs above)
+ obj_ids = inference_state["obj_ids"]
+ is_cond = any(
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
+ )
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ @torch.inference_mode()
+ def reset_state(self, inference_state):
+ """Remove all input points or mask in all frames throughout the video."""
+ self._reset_tracking_results(inference_state)
+ # Remove all object ids
+ inference_state["obj_id_to_idx"].clear()
+ inference_state["obj_idx_to_id"].clear()
+ inference_state["obj_ids"].clear()
+ inference_state["point_inputs_per_obj"].clear()
+ inference_state["mask_inputs_per_obj"].clear()
+ inference_state["output_dict_per_obj"].clear()
+ inference_state["temp_output_dict_per_obj"].clear()
+ inference_state["frames_tracked_per_obj"].clear()
+
+ def _reset_tracking_results(self, inference_state):
+ """Reset all tracking inputs and results across the videos."""
+ for v in inference_state["point_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["mask_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ for v in inference_state["temp_output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ for v in inference_state["frames_tracked_per_obj"].values():
+ v.clear()
+
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
+ """Compute the image features on a given frame."""
+ # Look up in the cache first
+ image, backbone_out = inference_state["cached_features"].get(
+ frame_idx, (None, None)
+ )
+ if backbone_out is None:
+ # Cache miss -- we will run inference on a single image
+ device = inference_state["device"]
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+ backbone_out = self.forward_image(image)
+ # Cache the most recent frame's feature (for repeated interactions with
+ # a frame; we can use an LRU cache for more frames in the future).
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+ # expand the features to have the same dimension as the number of objects
+ expanded_image = image.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out = {
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+ }
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+ batch_size, -1, -1, -1
+ )
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out["vision_pos_enc"][i] = pos
+
+ features = self._prepare_backbone_features(expanded_backbone_out)
+ features = (expanded_image,) + features
+ return features
+
+ def _run_single_frame_inference(
+ self,
+ inference_state,
+ output_dict,
+ frame_idx,
+ batch_size,
+ is_init_cond_frame,
+ point_inputs,
+ mask_inputs,
+ reverse,
+ run_mem_encoder,
+ prev_sam_mask_logits=None,
+ ):
+ """Run tracking on a single frame based on current inputs and previous memory."""
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # point and mask should not appear as input simultaneously on the same frame
+ assert point_inputs is None or mask_inputs is None
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ output_dict=output_dict,
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=reverse,
+ run_mem_encoder=run_mem_encoder,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = current_out["maskmem_features"]
+ if maskmem_features is not None:
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ pred_masks_gpu = current_out["pred_masks"]
+ # potentially fill holes in the predicted masks
+ if self.fill_hole_area > 0:
+ pred_masks_gpu = fill_holes_in_mask_scores(
+ pred_masks_gpu, self.fill_hole_area
+ )
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+ obj_ptr = current_out["obj_ptr"]
+ object_score_logits = current_out["object_score_logits"]
+ # make a compact version of this frame's output to reduce the state size
+ compact_current_out = {
+ "maskmem_features": maskmem_features,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ "pred_masks": pred_masks,
+ "obj_ptr": obj_ptr,
+ "object_score_logits": object_score_logits,
+ }
+ return compact_current_out, pred_masks_gpu
+
+ def _run_memory_encoder(
+ self,
+ inference_state,
+ frame_idx,
+ batch_size,
+ high_res_masks,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """
+ Run the memory encoder on `high_res_masks`. This is usually after applying
+ non-overlapping constraints to object scores. Since their scores changed, their
+ memory also need to be computed again with the memory encoder.
+ """
+ # Retrieve correct image features
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+ inference_state, frame_idx, batch_size
+ )
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ object_score_logits=object_score_logits,
+ is_mask_from_pts=is_mask_from_pts,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+ )
+ return maskmem_features, maskmem_pos_enc
+
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
+ """
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+ a constant in the inference session to reduce session storage size.
+ """
+ model_constants = inference_state["constants"]
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ if out_maskmem_pos_enc is not None:
+ if "maskmem_pos_enc" not in model_constants:
+ assert isinstance(out_maskmem_pos_enc, list)
+ # only take the slice for one object, since it's same across objects
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+ # expand the cached maskmem_pos_enc to the actual batch size
+ batch_size = out_maskmem_pos_enc[0].size(0)
+ expanded_maskmem_pos_enc = [
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+ ]
+ else:
+ expanded_maskmem_pos_enc = None
+ return expanded_maskmem_pos_enc
+
+ @torch.inference_mode()
+ def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
+ """
+ Remove an object id from the tracking state. If strict is True, we check whether
+ the object id actually exists and raise an error if it doesn't exist.
+ """
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
+ updated_frames = []
+ # Check whether this object_id to remove actually exists and possibly raise an error.
+ if old_obj_idx_to_rm is None:
+ if not strict:
+ return inference_state["obj_ids"], updated_frames
+ raise RuntimeError(
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
+ f"All existing object ids: {inference_state['obj_ids']}."
+ )
+
+ # If this is the only remaining object id, we simply reset the state.
+ if len(inference_state["obj_id_to_idx"]) == 1:
+ self.reset_state(inference_state)
+ return inference_state["obj_ids"], updated_frames
+
+ # There are still remaining objects after removing this object id. In this case,
+ # we need to delete the object storage from inference state tensors.
+ # Step 0: clear the input on those frames where this object id has point or mask input
+ # (note that this step is required as it might downgrade conditioning frames to
+ # non-conditioning ones)
+ obj_input_frames_inds = set()
+ obj_input_frames_inds.update(
+ inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
+ )
+ obj_input_frames_inds.update(
+ inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
+ )
+ for frame_idx in obj_input_frames_inds:
+ self.clear_all_prompts_in_frame(
+ inference_state, frame_idx, obj_id, need_output=False
+ )
+
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
+ # since Step 0 still requires the old object id mappings in inference_state)
+ old_obj_ids = inference_state["obj_ids"]
+ old_obj_inds = list(range(len(old_obj_ids)))
+ remain_old_obj_inds = old_obj_inds.copy()
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
+ new_obj_inds = list(range(len(new_obj_ids)))
+ # build new mappings
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
+ inference_state["obj_ids"] = new_obj_ids
+
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
+ def _map_keys(container):
+ new_kvs = []
+ for k in old_obj_inds:
+ v = container.pop(k)
+ if k in old_idx_to_new_idx:
+ new_kvs.append((old_idx_to_new_idx[k], v))
+ container.update(new_kvs)
+
+ _map_keys(inference_state["point_inputs_per_obj"])
+ _map_keys(inference_state["mask_inputs_per_obj"])
+ _map_keys(inference_state["output_dict_per_obj"])
+ _map_keys(inference_state["temp_output_dict_per_obj"])
+ _map_keys(inference_state["frames_tracked_per_obj"])
+
+ # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
+ # could show an updated mask for objects previously occluded by the object being removed
+ if need_output:
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ for frame_idx in obj_input_frames_inds:
+ is_cond = any(
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
+ )
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ updated_frames.append((frame_idx, video_res_masks))
+
+ return inference_state["obj_ids"], updated_frames
+
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+ """
+ Remove the non-conditioning memory around the input frame. When users provide
+ correction clicks, the surrounding frames' non-conditioning memories can still
+ contain outdated object appearance information and could confuse the model.
+
+ This method clears those non-conditioning memories surrounding the interacted
+ frame to avoid giving the model both old and new information about the object.
+ """
+ r = self.memory_temporal_stride_for_eval
+ frame_idx_begin = frame_idx - r * self.num_maskmem
+ frame_idx_end = frame_idx + r * self.num_maskmem
+ batch_size = self._get_obj_num(inference_state)
+ for obj_idx in range(batch_size):
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
+ for t in range(frame_idx_begin, frame_idx_end + 1):
+ non_cond_frame_outputs.pop(t, None)
+
+
+class SAM2VideoPredictorVOS(SAM2VideoPredictor):
+ """Optimized for the VOS setting"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._compile_all_components()
+
+ def _compile_all_components(self):
+ print("Compiling all components for VOS setting. First time may be very slow.")
+ self.memory_encoder.forward = torch.compile(
+ self.memory_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ self.memory_attention.forward = torch.compile(
+ self.memory_attention.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=True, # Num. of memories varies
+ )
+
+ self.sam_prompt_encoder.forward = torch.compile(
+ self.sam_prompt_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False, # Accuracy regression on True
+ )
+
+ self.sam_mask_decoder.forward = torch.compile(
+ self.sam_mask_decoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False, # Accuracy regression on True
+ )
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """
+ Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+ cloning the backbone features and pos encoding to enable compilation.
+ """
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+ backbone_out["backbone_fpn"][0]
+ )
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+ backbone_out["backbone_fpn"][1]
+ )
+ # Clone to help torch.compile
+ for i in range(len(backbone_out["backbone_fpn"])):
+ backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
+ backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
+ i
+ ].clone()
+ return backbone_out
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+ cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provide, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ # Clone image_pe and the outputs of sam_prompt_encoder
+ # to enable compilation
+ sparse_embeddings = sparse_embeddings.clone()
+ dense_embeddings = dense_embeddings.clone()
+ image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ # Clone the output of sam_mask_decoder
+ # to enable compilation
+ low_res_multimasks = low_res_multimasks.clone()
+ ious = ious.clone()
+ sam_output_tokens = sam_output_tokens.clone()
+ object_score_logits = object_score_logits.clone()
+
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """
+ Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+ cloning the memories and their pos enc to enable compilation.
+ """
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
+ pred_masks_high_res
+ )
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
+ )
+ # Clone the feats and pos_enc to enable compilation
+ maskmem_features = maskmem_out["vision_features"].clone()
+ maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.no_obj_embed_spatial is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (
+ 1 - is_obj_appearing[..., None, None]
+ ) * self.no_obj_embed_spatial[..., None, None].expand(
+ *maskmem_features.shape
+ )
+
+ return maskmem_features, maskmem_pos_enc
diff --git a/phantom/submodules/sam2/sam2/sam2_video_predictor_legacy.py b/phantom/submodules/sam2/sam2/sam2_video_predictor_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7e01ccf972491904b013526333826b337354db1
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/sam2_video_predictor_legacy.py
@@ -0,0 +1,1172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+from collections import OrderedDict
+
+import torch
+
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+
+class SAM2VideoPredictor(SAM2Base):
+ """The predictor class to handle user interactions and manage inference states."""
+
+ def __init__(
+ self,
+ fill_hole_area=0,
+ # whether to apply non-overlapping constraints on the output object masks
+ non_overlap_masks=False,
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+ # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
+ clear_non_cond_mem_around_input=False,
+ # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
+ clear_non_cond_mem_for_multi_obj=False,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+ # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
+ add_all_frames_to_correct_as_cond=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.fill_hole_area = fill_hole_area
+ self.non_overlap_masks = non_overlap_masks
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+
+ @torch.inference_mode()
+ def init_state(
+ self,
+ video_path,
+ offload_video_to_cpu=False,
+ offload_state_to_cpu=False,
+ async_loading_frames=False,
+ ):
+ """Initialize an inference state."""
+ compute_device = self.device # device of the model
+ images, video_height, video_width = load_video_frames(
+ video_path=video_path,
+ image_size=self.image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ async_loading_frames=async_loading_frames,
+ compute_device=compute_device,
+ )
+ inference_state = {}
+ inference_state["images"] = images
+ inference_state["num_frames"] = len(images)
+ # whether to offload the video frames to CPU memory
+ # turning on this option saves the GPU memory with only a very small overhead
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+ # whether to offload the inference state to CPU memory
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+ # and from 24 to 21 when tracking two objects)
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+ # the original video height and width, used for resizing final output scores
+ inference_state["video_height"] = video_height
+ inference_state["video_width"] = video_width
+ inference_state["device"] = compute_device
+ if offload_state_to_cpu:
+ inference_state["storage_device"] = torch.device("cpu")
+ else:
+ inference_state["storage_device"] = compute_device
+ # inputs on each frame
+ inference_state["point_inputs_per_obj"] = {}
+ inference_state["mask_inputs_per_obj"] = {}
+ # visual features on a small number of recently visited frames for quick interactions
+ inference_state["cached_features"] = {}
+ # values that don't change across frames (so we only need to hold one copy of them)
+ inference_state["constants"] = {}
+ # mapping between client-side object id and model-side object index
+ inference_state["obj_id_to_idx"] = OrderedDict()
+ inference_state["obj_idx_to_id"] = OrderedDict()
+ inference_state["obj_ids"] = []
+ # A storage to hold the model's tracking results and states on each frame
+ inference_state["output_dict"] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+ inference_state["output_dict_per_obj"] = {}
+ # A temporary storage to hold new outputs when user interact with a frame
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+ inference_state["temp_output_dict_per_obj"] = {}
+ # Frames that already holds consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ inference_state["consolidated_frame_inds"] = {
+ "cond_frame_outputs": set(), # set containing frame indices
+ "non_cond_frame_outputs": set(), # set containing frame indices
+ }
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"] = {}
+ # Warm up the visual backbone and cache the image feature on frame 0
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+ return inference_state
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+ **kwargs: Additional arguments to pass to the model constructor.
+
+ Returns:
+ (SAM2VideoPredictor): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_video_predictor_hf
+
+ sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+ return sam_model
+
+ def _obj_id_to_idx(self, inference_state, obj_id):
+ """Map client-side object id to model-side object index."""
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ # This is a new object id not sent to the server before. We only allow adding
+ # new objects *before* the tracking starts.
+ allow_new_object = not inference_state["tracking_has_started"]
+ if allow_new_object:
+ # get the next object slot
+ obj_idx = len(inference_state["obj_id_to_idx"])
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
+ # set up input and output structures for this object
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ inference_state["output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ return obj_idx
+ else:
+ raise RuntimeError(
+ f"Cannot add new object id {obj_id} after tracking starts. "
+ f"All existing object ids: {inference_state['obj_ids']}. "
+ f"Please call 'reset_state' to restart from scratch."
+ )
+
+ def _obj_idx_to_id(self, inference_state, obj_idx):
+ """Map model-side object index to client-side object id."""
+ return inference_state["obj_idx_to_id"][obj_idx]
+
+ def _get_obj_num(self, inference_state):
+ """Get the total number of unique object ids received so far in this session."""
+ return len(inference_state["obj_idx_to_id"])
+
+ @torch.inference_mode()
+ def add_new_points_or_box(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ points=None,
+ labels=None,
+ clear_old_points=True,
+ normalize_coords=True,
+ box=None,
+ ):
+ """Add new points to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if (points is not None) != (labels is not None):
+ raise ValueError("points and labels must be provided together")
+ if points is None and box is None:
+ raise ValueError("at least one of points or box must be provided as input")
+
+ if points is None:
+ points = torch.zeros(0, 2, dtype=torch.float32)
+ elif not isinstance(points, torch.Tensor):
+ points = torch.tensor(points, dtype=torch.float32)
+ if labels is None:
+ labels = torch.zeros(0, dtype=torch.int32)
+ elif not isinstance(labels, torch.Tensor):
+ labels = torch.tensor(labels, dtype=torch.int32)
+ if points.dim() == 2:
+ points = points.unsqueeze(0) # add batch dimension
+ if labels.dim() == 1:
+ labels = labels.unsqueeze(0) # add batch dimension
+
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
+ # along with the user-provided points (consistent with how SAM 2 is trained).
+ if box is not None:
+ if not clear_old_points:
+ raise ValueError(
+ "cannot add box without clearing old points, since "
+ "box prompt must be provided before any point prompt "
+ "(please use clear_old_points=True instead)"
+ )
+ if inference_state["tracking_has_started"]:
+ warnings.warn(
+ "You are adding a box after tracking starts. SAM 2 may not always be "
+ "able to incorporate a box prompt for *refinement*. If you intend to "
+ "use box prompt as an *initial* input before tracking, please call "
+ "'reset_state' on the inference state to restart from scratch.",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ if not isinstance(box, torch.Tensor):
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
+ box_coords = box.reshape(1, 2, 2)
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+ box_labels = box_labels.reshape(1, 2)
+ points = torch.cat([box_coords, points], dim=1)
+ labels = torch.cat([box_labels, labels], dim=1)
+
+ if normalize_coords:
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
+ # scale the (normalized) coordinates by the model's internal image size
+ points = points * self.image_size
+ points = points.to(inference_state["device"])
+ labels = labels.to(inference_state["device"])
+
+ if not clear_old_points:
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
+ else:
+ point_inputs = None
+ point_inputs = concat_points(point_inputs, points, labels)
+
+ point_inputs_per_frame[frame_idx] = point_inputs
+ mask_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the inputs points are to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ # Get any previously predicted mask logits on this object and feed it along with
+ # the new clicks into the SAM mask decoder.
+ prev_sam_mask_logits = None
+ # lookup temporary output dict first, which contains the most recent output
+ # (if not found, then lookup conditioning and non-conditioning frame output)
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+ if prev_out is not None and prev_out["pred_masks"] is not None:
+ device = inference_state["device"]
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=None,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ def add_new_points(self, *args, **kwargs):
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
+ return self.add_new_points_or_box(*args, **kwargs)
+
+ @torch.inference_mode()
+ def add_new_mask(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ mask,
+ ):
+ """Add new mask to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.tensor(mask, dtype=torch.bool)
+ assert mask.dim() == 2
+ mask_H, mask_W = mask.shape
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+ # resize the mask if it doesn't match the model's image size
+ if mask_H != self.image_size or mask_W != self.image_size:
+ mask_inputs = torch.nn.functional.interpolate(
+ mask_inputs_orig,
+ size=(self.image_size, self.image_size),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ mask_inputs = (mask_inputs >= 0.5).float()
+ else:
+ mask_inputs = mask_inputs_orig
+
+ mask_inputs_per_frame[frame_idx] = mask_inputs
+ point_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the inputs points are to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
+ """
+ Resize the object scores to the original video resolution (video_res_masks)
+ and apply non-overlapping constraints for final output.
+ """
+ device = inference_state["device"]
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
+ if any_res_masks.shape[-2:] == (video_H, video_W):
+ video_res_masks = any_res_masks
+ else:
+ video_res_masks = torch.nn.functional.interpolate(
+ any_res_masks,
+ size=(video_H, video_W),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks:
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
+ return any_res_masks, video_res_masks
+
+ def _consolidate_temp_output_across_obj(
+ self,
+ inference_state,
+ frame_idx,
+ is_cond,
+ run_mem_encoder,
+ consolidate_at_video_res=False,
+ ):
+ """
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+ a frame into a single output for all objects, including
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
+ (if they don't exist in `output_dict_per_obj` for this frame);
+ 2) if specified, rerun memory encoder after apply non-overlapping constraints
+ on the object scores.
+ """
+ batch_size = self._get_obj_num(inference_state)
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Optionally, we allow consolidating the temporary outputs at the original
+ # video resolution (to provide a better editing experience for mask prompts).
+ if consolidate_at_video_res:
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
+ consolidated_H = inference_state["video_height"]
+ consolidated_W = inference_state["video_width"]
+ consolidated_mask_key = "pred_masks_video_res"
+ else:
+ consolidated_H = consolidated_W = self.image_size // 4
+ consolidated_mask_key = "pred_masks"
+
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+ # will be added when rerunning the memory encoder after applying non-overlapping
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ consolidated_mask_key: torch.full(
+ size=(batch_size, 1, consolidated_H, consolidated_W),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["storage_device"],
+ ),
+ "obj_ptr": torch.full(
+ size=(batch_size, self.hidden_dim),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["device"],
+ ),
+ "object_score_logits": torch.full(
+ size=(batch_size, 1),
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+ fill_value=10.0,
+ dtype=torch.float32,
+ device=inference_state["device"],
+ ),
+ }
+ empty_mask_ptr = None
+ for obj_idx in range(batch_size):
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+ # we fall back and look up its previous output in "output_dict_per_obj".
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+ # "output_dict_per_obj" to find a previous output for this object.
+ if out is None:
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
+ if out is None:
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+ # placeholder above) and set its object pointer to be a dummy pointer.
+ if out is None:
+ # Fill in dummy object pointers for those objects without any inputs or
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+ # i.e. when we need to build the memory for tracking).
+ if run_mem_encoder:
+ if empty_mask_ptr is None:
+ empty_mask_ptr = self._get_empty_mask_ptr(
+ inference_state, frame_idx
+ )
+ # fill object pointer with a dummy pointer (based on an empty mask)
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
+ continue
+ # Add the temporary object output mask to consolidated output mask
+ obj_mask = out["pred_masks"]
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
+ else:
+ # Resize first if temporary object mask has a different resolution
+ resized_obj_mask = torch.nn.functional.interpolate(
+ obj_mask,
+ size=consolidated_pred_masks.shape[-2:],
+ mode="bilinear",
+ align_corners=False,
+ )
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
+ "object_score_logits"
+ ]
+
+ # Optionally, apply non-overlapping constraints on the consolidated scores
+ # and rerun the memory encoder
+ if run_mem_encoder:
+ device = inference_state["device"]
+ high_res_masks = torch.nn.functional.interpolate(
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks_for_mem_enc:
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+ inference_state=inference_state,
+ frame_idx=frame_idx,
+ batch_size=batch_size,
+ high_res_masks=high_res_masks,
+ object_score_logits=consolidated_out["object_score_logits"],
+ is_mask_from_pts=True, # these frames are what the user interacted with
+ )
+ consolidated_out["maskmem_features"] = maskmem_features
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+
+ return consolidated_out
+
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
+ """Get a dummy object pointer based on an empty mask on the current frame."""
+ # A dummy (empty) mask with a single object
+ batch_size = 1
+ mask_inputs = torch.zeros(
+ (batch_size, 1, self.image_size, self.image_size),
+ dtype=torch.float32,
+ device=inference_state["device"],
+ )
+
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # Feed the empty mask and image feature above to get a dummy object pointer
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=True,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ output_dict={},
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=False,
+ run_mem_encoder=False,
+ prev_sam_mask_logits=None,
+ )
+ return current_out["obj_ptr"]
+
+ @torch.inference_mode()
+ def propagate_in_video_preflight(self, inference_state):
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
+ # Tracking has started and we don't allow adding new objects until session is reset.
+ inference_state["tracking_has_started"] = True
+ batch_size = self._get_obj_num(inference_state)
+
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+ # add them into "output_dict".
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ output_dict = inference_state["output_dict"]
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
+ # temporary outputs have been added (either in this call or any previous calls
+ # to `propagate_in_video_preflight`).
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ for is_cond in [False, True]:
+ # Separately consolidate conditioning and non-conditioning temp outputs
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Find all the frames that contain temporary outputs for any objects
+ # (these should be the frames that have just received clicks for mask inputs
+ # via `add_new_points_or_box` or `add_new_mask`)
+ temp_frame_inds = set()
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
+ # consolidate the temporary output across all objects on this frame
+ for frame_idx in temp_frame_inds:
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+ )
+ # merge them into "output_dict" and also create per-object slices
+ output_dict[storage_key][frame_idx] = consolidated_out
+ self._add_output_per_object(
+ inference_state, frame_idx, consolidated_out, storage_key
+ )
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+ )
+ if clear_non_cond_mem:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+
+ # clear temporary outputs in `temp_output_dict_per_obj`
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ obj_temp_output_dict[storage_key].clear()
+
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+ # output on the same frame in "non_cond_frame_outputs"
+ for frame_idx in output_dict["cond_frame_outputs"]:
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ assert frame_idx in output_dict["cond_frame_outputs"]
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+ # with either points or mask inputs (which should be true under a correct workflow).
+ all_consolidated_frame_inds = (
+ consolidated_frame_inds["cond_frame_outputs"]
+ | consolidated_frame_inds["non_cond_frame_outputs"]
+ )
+ input_frames_inds = set()
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
+ input_frames_inds.update(point_inputs_per_frame.keys())
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
+ input_frames_inds.update(mask_inputs_per_frame.keys())
+ assert all_consolidated_frame_inds == input_frames_inds
+
+ @torch.inference_mode()
+ def propagate_in_video(
+ self,
+ inference_state,
+ start_frame_idx=None,
+ max_frame_num_to_track=None,
+ reverse=False,
+ ):
+ """Propagate the input points across frames to track in the entire video."""
+ self.propagate_in_video_preflight(inference_state)
+
+ output_dict = inference_state["output_dict"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ obj_ids = inference_state["obj_ids"]
+ num_frames = inference_state["num_frames"]
+ batch_size = self._get_obj_num(inference_state)
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ raise RuntimeError("No points are provided; please add points first")
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+ )
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
+ )
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ # We skip those frames already in consolidated outputs (these are frames
+ # that received input clicks or mask). Note that we cannot directly run
+ # batched forward on them via `_run_single_frame_inference` because the
+ # number of clicks on each object might be different.
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = output_dict[storage_key][frame_idx]
+ pred_masks = current_out["pred_masks"]
+ if clear_non_cond_mem:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
+ storage_key = "non_cond_frame_outputs"
+ current_out = output_dict[storage_key][frame_idx]
+ pred_masks = current_out["pred_masks"]
+ else:
+ storage_key = "non_cond_frame_outputs"
+ current_out, pred_masks = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=output_dict,
+ frame_idx=frame_idx,
+ batch_size=batch_size,
+ is_init_cond_frame=False,
+ point_inputs=None,
+ mask_inputs=None,
+ reverse=reverse,
+ run_mem_encoder=True,
+ )
+ output_dict[storage_key][frame_idx] = current_out
+ # Create slices of per-object outputs for subsequent interaction with each
+ # individual object after tracking.
+ self._add_output_per_object(
+ inference_state, frame_idx, current_out, storage_key
+ )
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
+
+ # Resize the output mask to the original video resolution (we directly use
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, pred_masks
+ )
+ yield frame_idx, obj_ids, video_res_masks
+
+ def _add_output_per_object(
+ self, inference_state, frame_idx, current_out, storage_key
+ ):
+ """
+ Split a multi-object output into per-object output slices and add them into
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
+ """
+ maskmem_features = current_out["maskmem_features"]
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
+ obj_slice = slice(obj_idx, obj_idx + 1)
+ obj_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": current_out["pred_masks"][obj_slice],
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
+ "object_score_logits": current_out["object_score_logits"][obj_slice],
+ }
+ if maskmem_features is not None:
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
+ if maskmem_pos_enc is not None:
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+ obj_output_dict[storage_key][frame_idx] = obj_out
+
+ @torch.inference_mode()
+ def clear_all_prompts_in_frame(
+ self, inference_state, frame_idx, obj_id, need_output=True
+ ):
+ """Remove all input points or mask in a specific frame for a given object."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+
+ # Clear the conditioning information on the given frame
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
+
+ # Check and see if there are still any inputs left on this frame
+ batch_size = self._get_obj_num(inference_state)
+ frame_has_input = False
+ for obj_idx2 in range(batch_size):
+ if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
+ frame_has_input = True
+ break
+ if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
+ frame_has_input = True
+ break
+
+ # If this frame has no remaining inputs for any objects, we further clear its
+ # conditioning frame status
+ if not frame_has_input:
+ output_dict = inference_state["output_dict"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
+ out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if out is not None:
+ # The frame is not a conditioning frame anymore since it's not receiving inputs,
+ # so we "downgrade" its output (if exists) to a non-conditioning frame output.
+ output_dict["non_cond_frame_outputs"][frame_idx] = out
+ inference_state["frames_already_tracked"].pop(frame_idx, None)
+ # Similarly, do it for the sliced output on each object.
+ for obj_idx2 in range(batch_size):
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
+ obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if obj_out is not None:
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
+
+ # If all the conditioning frames have been removed, we also clear the tracking outputs
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ self._reset_tracking_results(inference_state)
+
+ if not need_output:
+ return
+ # Finally, output updated masks per object (after removing the inputs above)
+ obj_ids = inference_state["obj_ids"]
+ is_cond = any(
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
+ )
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ @torch.inference_mode()
+ def reset_state(self, inference_state):
+ """Remove all input points or mask in all frames throughout the video."""
+ self._reset_tracking_results(inference_state)
+ # Remove all object ids
+ inference_state["obj_id_to_idx"].clear()
+ inference_state["obj_idx_to_id"].clear()
+ inference_state["obj_ids"].clear()
+ inference_state["point_inputs_per_obj"].clear()
+ inference_state["mask_inputs_per_obj"].clear()
+ inference_state["output_dict_per_obj"].clear()
+ inference_state["temp_output_dict_per_obj"].clear()
+
+ def _reset_tracking_results(self, inference_state):
+ """Reset all tracking inputs and results across the videos."""
+ for v in inference_state["point_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["mask_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ for v in inference_state["temp_output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"].clear()
+
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
+ """Compute the image features on a given frame."""
+ # Look up in the cache first
+ image, backbone_out = inference_state["cached_features"].get(
+ frame_idx, (None, None)
+ )
+ if backbone_out is None:
+ # Cache miss -- we will run inference on a single image
+ device = inference_state["device"]
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+ backbone_out = self.forward_image(image)
+ # Cache the most recent frame's feature (for repeated interactions with
+ # a frame; we can use an LRU cache for more frames in the future).
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+ # expand the features to have the same dimension as the number of objects
+ expanded_image = image.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out = {
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+ }
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+ batch_size, -1, -1, -1
+ )
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out["vision_pos_enc"][i] = pos
+
+ features = self._prepare_backbone_features(expanded_backbone_out)
+ features = (expanded_image,) + features
+ return features
+
+ def _run_single_frame_inference(
+ self,
+ inference_state,
+ output_dict,
+ frame_idx,
+ batch_size,
+ is_init_cond_frame,
+ point_inputs,
+ mask_inputs,
+ reverse,
+ run_mem_encoder,
+ prev_sam_mask_logits=None,
+ ):
+ """Run tracking on a single frame based on current inputs and previous memory."""
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # point and mask should not appear as input simultaneously on the same frame
+ assert point_inputs is None or mask_inputs is None
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ output_dict=output_dict,
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=reverse,
+ run_mem_encoder=run_mem_encoder,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = current_out["maskmem_features"]
+ if maskmem_features is not None:
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ pred_masks_gpu = current_out["pred_masks"]
+ # potentially fill holes in the predicted masks
+ if self.fill_hole_area > 0:
+ pred_masks_gpu = fill_holes_in_mask_scores(
+ pred_masks_gpu, self.fill_hole_area
+ )
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+ obj_ptr = current_out["obj_ptr"]
+ object_score_logits = current_out["object_score_logits"]
+ # make a compact version of this frame's output to reduce the state size
+ compact_current_out = {
+ "maskmem_features": maskmem_features,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ "pred_masks": pred_masks,
+ "obj_ptr": obj_ptr,
+ "object_score_logits": object_score_logits,
+ }
+ return compact_current_out, pred_masks_gpu
+
+ def _run_memory_encoder(
+ self,
+ inference_state,
+ frame_idx,
+ batch_size,
+ high_res_masks,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """
+ Run the memory encoder on `high_res_masks`. This is usually after applying
+ non-overlapping constraints to object scores. Since their scores changed, their
+ memory also need to be computed again with the memory encoder.
+ """
+ # Retrieve correct image features
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+ inference_state, frame_idx, batch_size
+ )
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ object_score_logits=object_score_logits,
+ is_mask_from_pts=is_mask_from_pts,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+ )
+ return maskmem_features, maskmem_pos_enc
+
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
+ """
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+ a constant in the inference session to reduce session storage size.
+ """
+ model_constants = inference_state["constants"]
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ if out_maskmem_pos_enc is not None:
+ if "maskmem_pos_enc" not in model_constants:
+ assert isinstance(out_maskmem_pos_enc, list)
+ # only take the slice for one object, since it's same across objects
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+ # expand the cached maskmem_pos_enc to the actual batch size
+ batch_size = out_maskmem_pos_enc[0].size(0)
+ expanded_maskmem_pos_enc = [
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+ ]
+ else:
+ expanded_maskmem_pos_enc = None
+ return expanded_maskmem_pos_enc
+
+ @torch.inference_mode()
+ def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
+ """
+ Remove an object id from the tracking state. If strict is True, we check whether
+ the object id actually exists and raise an error if it doesn't exist.
+ """
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
+ updated_frames = []
+ # Check whether this object_id to remove actually exists and possibly raise an error.
+ if old_obj_idx_to_rm is None:
+ if not strict:
+ return inference_state["obj_ids"], updated_frames
+ raise RuntimeError(
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
+ f"All existing object ids: {inference_state['obj_ids']}."
+ )
+
+ # If this is the only remaining object id, we simply reset the state.
+ if len(inference_state["obj_id_to_idx"]) == 1:
+ self.reset_state(inference_state)
+ return inference_state["obj_ids"], updated_frames
+
+ # There are still remaining objects after removing this object id. In this case,
+ # we need to delete the object storage from inference state tensors.
+ # Step 0: clear the input on those frames where this object id has point or mask input
+ # (note that this step is required as it might downgrade conditioning frames to
+ # non-conditioning ones)
+ obj_input_frames_inds = set()
+ obj_input_frames_inds.update(
+ inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
+ )
+ obj_input_frames_inds.update(
+ inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
+ )
+ for frame_idx in obj_input_frames_inds:
+ self.clear_all_prompts_in_frame(
+ inference_state, frame_idx, obj_id, need_output=False
+ )
+
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
+ # since Step 0 still requires the old object id mappings in inference_state)
+ old_obj_ids = inference_state["obj_ids"]
+ old_obj_inds = list(range(len(old_obj_ids)))
+ remain_old_obj_inds = old_obj_inds.copy()
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
+ new_obj_inds = list(range(len(new_obj_ids)))
+ # build new mappings
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
+ inference_state["obj_ids"] = new_obj_ids
+
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
+ # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
+ # it's already handled in Step 0)
+ def _map_keys(container):
+ new_kvs = []
+ for k in old_obj_inds:
+ v = container.pop(k)
+ if k in old_idx_to_new_idx:
+ new_kvs.append((old_idx_to_new_idx[k], v))
+ container.update(new_kvs)
+
+ _map_keys(inference_state["point_inputs_per_obj"])
+ _map_keys(inference_state["mask_inputs_per_obj"])
+ _map_keys(inference_state["output_dict_per_obj"])
+ _map_keys(inference_state["temp_output_dict_per_obj"])
+
+ # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
+ def _slice_state(output_dict, storage_key):
+ for frame_idx, out in output_dict[storage_key].items():
+ out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
+ out["maskmem_pos_enc"] = [
+ x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
+ ]
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
+ out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
+ out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
+ out["object_score_logits"] = out["object_score_logits"][
+ remain_old_obj_inds
+ ]
+ # also update the per-object slices
+ self._add_output_per_object(
+ inference_state, frame_idx, out, storage_key
+ )
+
+ _slice_state(inference_state["output_dict"], "cond_frame_outputs")
+ _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
+
+ # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
+ # could show an updated mask for objects previously occluded by the object being removed
+ if need_output:
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ for frame_idx in obj_input_frames_inds:
+ is_cond = any(
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
+ )
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ updated_frames.append((frame_idx, video_res_masks))
+
+ return inference_state["obj_ids"], updated_frames
+
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+ """
+ Remove the non-conditioning memory around the input frame. When users provide
+ correction clicks, the surrounding frames' non-conditioning memories can still
+ contain outdated object appearance information and could confuse the model.
+
+ This method clears those non-conditioning memories surrounding the interacted
+ frame to avoid giving the model both old and new information about the object.
+ """
+ r = self.memory_temporal_stride_for_eval
+ frame_idx_begin = frame_idx - r * self.num_maskmem
+ frame_idx_end = frame_idx + r * self.num_maskmem
+ output_dict = inference_state["output_dict"]
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
+ for t in range(frame_idx_begin, frame_idx_end + 1):
+ non_cond_frame_outputs.pop(t, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
diff --git a/phantom/submodules/sam2/sam2/utils/__init__.py b/phantom/submodules/sam2/sam2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/sam2/utils/amg.py b/phantom/submodules/sam2/sam2/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/utils/amg.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+import numpy as np
+import torch
+
+# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.float().detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+ # Crops in XYWH format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of if the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
diff --git a/phantom/submodules/sam2/sam2/utils/misc.py b/phantom/submodules/sam2/sam2/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/utils/misc.py
@@ -0,0 +1,349 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import warnings
+from threading import Thread
+
+import numpy as np
+import torch
+from PIL import Image
+from tqdm import tqdm
+
+
+def get_sdpa_settings():
+ if torch.cuda.is_available():
+ old_gpu = torch.cuda.get_device_properties(0).major < 7
+ # only use Flash Attention on Ampere (8.0) or newer GPUs
+ use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
+ if not use_flash_attn:
+ warnings.warn(
+ "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
+ # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
+ pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
+ if pytorch_version < (2, 2):
+ warnings.warn(
+ f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
+ "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
+ else:
+ old_gpu = True
+ use_flash_attn = False
+ math_kernel_on = True
+
+ return old_gpu, use_flash_attn, math_kernel_on
+
+
+def get_connected_components(mask):
+ """
+ Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
+
+ Inputs:
+ - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
+ background.
+
+ Outputs:
+ - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
+ for foreground pixels and 0 for background pixels.
+ - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
+ components for foreground pixels and 0 for background pixels.
+ """
+ from sam2 import _C
+
+ return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
+
+
+def mask_to_box(masks: torch.Tensor):
+ """
+ compute bounding box given an input mask
+
+ Inputs:
+ - masks: [B, 1, H, W] masks, dtype=torch.Tensor
+
+ Returns:
+ - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
+ """
+ B, _, h, w = masks.shape
+ device = masks.device
+ xs = torch.arange(w, device=device, dtype=torch.int32)
+ ys = torch.arange(h, device=device, dtype=torch.int32)
+ grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
+ grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
+ grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
+ min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
+ max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
+ min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
+ max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
+ bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
+
+ return bbox_coords
+
+
+def _load_img_as_tensor(img_path, image_size):
+ img_pil = Image.open(img_path)
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
+ if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
+ img_np = img_np / 255.0
+ else:
+ raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
+ video_width, video_height = img_pil.size # the original video size
+ return img, video_height, video_width
+
+
+class AsyncVideoFrameLoader:
+ """
+ A list of video frames to be load asynchronously without blocking session start.
+ """
+
+ def __init__(
+ self,
+ img_paths,
+ image_size,
+ offload_video_to_cpu,
+ img_mean,
+ img_std,
+ compute_device,
+ ):
+ self.img_paths = img_paths
+ self.image_size = image_size
+ self.offload_video_to_cpu = offload_video_to_cpu
+ self.img_mean = img_mean
+ self.img_std = img_std
+ # items in `self.images` will be loaded asynchronously
+ self.images = [None] * len(img_paths)
+ # catch and raise any exceptions in the async loading thread
+ self.exception = None
+ # video_height and video_width be filled when loading the first image
+ self.video_height = None
+ self.video_width = None
+ self.compute_device = compute_device
+
+ # load the first frame to fill video_height and video_width and also
+ # to cache it (since it's most likely where the user will click)
+ self.__getitem__(0)
+
+ # load the rest of frames asynchronously without blocking the session start
+ def _load_frames():
+ try:
+ for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
+ self.__getitem__(n)
+ except Exception as e:
+ self.exception = e
+
+ self.thread = Thread(target=_load_frames, daemon=True)
+ self.thread.start()
+
+ def __getitem__(self, index):
+ if self.exception is not None:
+ raise RuntimeError("Failure in frame loading thread") from self.exception
+
+ img = self.images[index]
+ if img is not None:
+ return img
+
+ img, video_height, video_width = _load_img_as_tensor(
+ self.img_paths[index], self.image_size
+ )
+ self.video_height = video_height
+ self.video_width = video_width
+ # normalize by mean and std
+ img -= self.img_mean
+ img /= self.img_std
+ if not self.offload_video_to_cpu:
+ img = img.to(self.compute_device, non_blocking=True)
+ self.images[index] = img
+ return img
+
+ def __len__(self):
+ return len(self.images)
+
+
+def load_video_frames(
+ video_path,
+ image_size,
+ offload_video_to_cpu,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ async_loading_frames=False,
+ compute_device=torch.device("cuda"),
+):
+ """
+ Load the video frames from video_path. The frames are resized to image_size as in
+ the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
+ """
+ is_bytes = isinstance(video_path, bytes)
+ is_str = isinstance(video_path, str)
+ is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
+ if is_bytes or is_mp4_path:
+ return load_video_frames_from_video_file(
+ video_path=video_path,
+ image_size=image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ img_mean=img_mean,
+ img_std=img_std,
+ compute_device=compute_device,
+ )
+ elif is_str and os.path.isdir(video_path):
+ return load_video_frames_from_jpg_images(
+ video_path=video_path,
+ image_size=image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ img_mean=img_mean,
+ img_std=img_std,
+ async_loading_frames=async_loading_frames,
+ compute_device=compute_device,
+ )
+ else:
+ raise NotImplementedError(
+ "Only MP4 video and JPEG folder are supported at this moment"
+ )
+
+
+def load_video_frames_from_jpg_images(
+ video_path,
+ image_size,
+ offload_video_to_cpu,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ async_loading_frames=False,
+ compute_device=torch.device("cuda"),
+):
+ """
+ Load the video frames from a directory of JPEG files (".jpg" format).
+
+ The frames are resized to image_size x image_size and are loaded to GPU if
+ `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
+
+ You can load a frame asynchronously by setting `async_loading_frames` to `True`.
+ """
+ if isinstance(video_path, str) and os.path.isdir(video_path):
+ jpg_folder = video_path
+ else:
+ raise NotImplementedError(
+ "Only JPEG frames are supported at this moment. For video files, you may use "
+ "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
+ "```\n"
+ "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n"
+ "```\n"
+ "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
+ "ffmpeg to start the JPEG file from 00000.jpg."
+ )
+
+ frame_names = [
+ p
+ for p in os.listdir(jpg_folder)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+ ]
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+ num_frames = len(frame_names)
+ if num_frames == 0:
+ raise RuntimeError(f"no images found in {jpg_folder}")
+ img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+ if async_loading_frames:
+ lazy_images = AsyncVideoFrameLoader(
+ img_paths,
+ image_size,
+ offload_video_to_cpu,
+ img_mean,
+ img_std,
+ compute_device,
+ )
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
+
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
+ for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+ if not offload_video_to_cpu:
+ images = images.to(compute_device)
+ img_mean = img_mean.to(compute_device)
+ img_std = img_std.to(compute_device)
+ # normalize by mean and std
+ images -= img_mean
+ images /= img_std
+ return images, video_height, video_width
+
+
+def load_video_frames_from_video_file(
+ video_path,
+ image_size,
+ offload_video_to_cpu,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ compute_device=torch.device("cuda"),
+):
+ """Load the video frames from a video file."""
+ import decord
+
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+ # Get the original video height and width
+ decord.bridge.set_bridge("torch")
+ video_height, video_width, _ = decord.VideoReader(video_path).next().shape
+ # Iterate over all frames in the video
+ images = []
+ for frame in decord.VideoReader(video_path, width=image_size, height=image_size):
+ images.append(frame.permute(2, 0, 1))
+
+ images = torch.stack(images, dim=0).float() / 255.0
+ if not offload_video_to_cpu:
+ images = images.to(compute_device)
+ img_mean = img_mean.to(compute_device)
+ img_std = img_std.to(compute_device)
+ # normalize by mean and std
+ images -= img_mean
+ images /= img_std
+ return images, video_height, video_width
+
+
+def fill_holes_in_mask_scores(mask, max_area):
+ """
+ A post processor to fill small holes in mask scores with area under `max_area`.
+ """
+ # Holes are those connected components in background with area <= self.max_area
+ # (background regions are those with mask scores <= 0)
+ assert max_area > 0, "max_area must be positive"
+
+ input_mask = mask
+ try:
+ labels, areas = get_connected_components(mask <= 0)
+ is_hole = (labels > 0) & (areas <= max_area)
+ # We fill holes with a small positive mask score (0.1) to change them to foreground.
+ mask = torch.where(is_hole, 0.1, mask)
+ except Exception as e:
+ # Skip the post-processing step on removing small holes if the CUDA kernel fails
+ warnings.warn(
+ f"{e}\n\nSkipping the post-processing step due to the error above. You can "
+ "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
+ "functionality may be limited (which doesn't affect the results in most cases; see "
+ "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ mask = input_mask
+
+ return mask
+
+
+def concat_points(old_point_inputs, new_points, new_labels):
+ """Add new points and labels to previous point inputs (add at the end)."""
+ if old_point_inputs is None:
+ points, labels = new_points, new_labels
+ else:
+ points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
+ labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
+
+ return {"point_coords": points, "point_labels": labels}
diff --git a/phantom/submodules/sam2/sam2/utils/transforms.py b/phantom/submodules/sam2/sam2/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc17bebfab104b659c5469e8434cf357ae7e24b6
--- /dev/null
+++ b/phantom/submodules/sam2/sam2/utils/transforms.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Normalize, Resize, ToTensor
+
+
+class SAM2Transforms(nn.Module):
+ def __init__(
+ self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
+ ):
+ """
+ Transforms for SAM2.
+ """
+ super().__init__()
+ self.resolution = resolution
+ self.mask_threshold = mask_threshold
+ self.max_hole_area = max_hole_area
+ self.max_sprinkle_area = max_sprinkle_area
+ self.mean = [0.485, 0.456, 0.406]
+ self.std = [0.229, 0.224, 0.225]
+ self.to_tensor = ToTensor()
+ self.transforms = torch.jit.script(
+ nn.Sequential(
+ Resize((self.resolution, self.resolution)),
+ Normalize(self.mean, self.std),
+ )
+ )
+
+ def __call__(self, x):
+ x = self.to_tensor(x)
+ return self.transforms(x)
+
+ def forward_batch(self, img_list):
+ img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
+ img_batch = torch.stack(img_batch, dim=0)
+ return img_batch
+
+ def transform_coords(
+ self, coords: torch.Tensor, normalize=False, orig_hw=None
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates,
+ If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
+
+ Returns
+ Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model.
+ """
+ if normalize:
+ assert orig_hw is not None
+ h, w = orig_hw
+ coords = coords.clone()
+ coords[..., 0] = coords[..., 0] / w
+ coords[..., 1] = coords[..., 1] / h
+
+ coords = coords * self.resolution # unnormalize coords
+ return coords
+
+ def transform_boxes(
+ self, boxes: torch.Tensor, normalize=False, orig_hw=None
+ ) -> torch.Tensor:
+ """
+ Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
+ if the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
+ """
+ boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
+ return boxes
+
+ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
+ """
+ Perform PostProcessing on output masks.
+ """
+ from sam2.utils.misc import get_connected_components
+
+ masks = masks.float()
+ input_masks = masks
+ mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
+ try:
+ if self.max_hole_area > 0:
+ # Holes are those connected components in background with area <= self.fill_hole_area
+ # (background regions are those with mask scores <= self.mask_threshold)
+ labels, areas = get_connected_components(
+ mask_flat <= self.mask_threshold
+ )
+ is_hole = (labels > 0) & (areas <= self.max_hole_area)
+ is_hole = is_hole.reshape_as(masks)
+ # We fill holes with a small positive mask score (10.0) to change them to foreground.
+ masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
+
+ if self.max_sprinkle_area > 0:
+ labels, areas = get_connected_components(
+ mask_flat > self.mask_threshold
+ )
+ is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
+ is_hole = is_hole.reshape_as(masks)
+ # We fill holes with negative mask score (-10.0) to change them to background.
+ masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
+ except Exception as e:
+ # Skip the post-processing step if the CUDA kernel fails
+ warnings.warn(
+ f"{e}\n\nSkipping the post-processing step due to the error above. You can "
+ "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
+ "functionality may be limited (which doesn't affect the results in most cases; see "
+ "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ masks = input_masks
+
+ masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
+ return masks
diff --git a/phantom/submodules/sam2/setup.py b/phantom/submodules/sam2/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..78a634cddb19615c45601681ffbcd1f29af66f47
--- /dev/null
+++ b/phantom/submodules/sam2/setup.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+
+from setuptools import find_packages, setup
+
+# Package metadata
+NAME = "SAM-2"
+VERSION = "1.0"
+DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
+URL = "https://github.com/facebookresearch/sam2"
+AUTHOR = "Meta AI"
+AUTHOR_EMAIL = "segment-anything@meta.com"
+LICENSE = "Apache 2.0"
+
+# Read the contents of README file
+with open("README.md", "r", encoding="utf-8") as f:
+ LONG_DESCRIPTION = f.read()
+
+# Required dependencies
+REQUIRED_PACKAGES = [
+ "torch>=2.5.1",
+ "torchvision>=0.20.1",
+ "numpy>=1.24.4",
+ "tqdm>=4.66.1",
+ "hydra-core>=1.3.2",
+ "iopath>=0.1.10",
+ "pillow>=9.4.0",
+]
+
+EXTRA_PACKAGES = {
+ "notebooks": [
+ "matplotlib>=3.9.1",
+ "jupyter>=1.0.0",
+ "opencv-python>=4.7.0",
+ "eva-decord>=0.6.1",
+ ],
+ "interactive-demo": [
+ "Flask>=3.0.3",
+ "Flask-Cors>=5.0.0",
+ "av>=13.0.0",
+ "dataclasses-json>=0.6.7",
+ "eva-decord>=0.6.1",
+ "gunicorn>=23.0.0",
+ "imagesize>=1.4.1",
+ "pycocotools>=2.0.8",
+ "strawberry-graphql>=0.243.0",
+ ],
+ "dev": [
+ "black==24.2.0",
+ "usort==1.0.2",
+ "ufmt==2.0.0b2",
+ "fvcore>=0.1.5.post20221221",
+ "pandas>=2.2.2",
+ "scikit-image>=0.24.0",
+ "tensorboard>=2.17.0",
+ "pycocotools>=2.0.8",
+ "tensordict>=0.6.0",
+ "opencv-python>=4.7.0",
+ "submitit>=1.5.1",
+ ],
+}
+
+# By default, we also build the SAM 2 CUDA extension.
+# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
+BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
+# By default, we allow SAM 2 installation to proceed even with build errors.
+# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
+BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
+
+# Catch and skip errors during extension building and print a warning message
+# (note that this message only shows up under verbose build mode
+# "pip install -v -e ." or "python setup.py build_ext -v")
+CUDA_ERROR_MSG = (
+ "{}\n\n"
+ "Failed to build the SAM 2 CUDA extension due to the error above. "
+ "You can still use SAM 2 and it's OK to ignore the error above, although some "
+ "post-processing functionality may be limited (which doesn't affect the results in most cases; "
+ "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
+)
+
+
+def get_extensions():
+ if not BUILD_CUDA:
+ return []
+
+ try:
+ from torch.utils.cpp_extension import CUDAExtension
+
+ srcs = ["sam2/csrc/connected_components.cu"]
+ compile_args = {
+ "cxx": [],
+ "nvcc": [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ],
+ }
+ ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+ except Exception as e:
+ if BUILD_ALLOW_ERRORS:
+ print(CUDA_ERROR_MSG.format(e))
+ ext_modules = []
+ else:
+ raise e
+
+ return ext_modules
+
+
+try:
+ from torch.utils.cpp_extension import BuildExtension
+
+ class BuildExtensionIgnoreErrors(BuildExtension):
+
+ def finalize_options(self):
+ try:
+ super().finalize_options()
+ except Exception as e:
+ print(CUDA_ERROR_MSG.format(e))
+ self.extensions = []
+
+ def build_extensions(self):
+ try:
+ super().build_extensions()
+ except Exception as e:
+ print(CUDA_ERROR_MSG.format(e))
+ self.extensions = []
+
+ def get_ext_filename(self, ext_name):
+ try:
+ return super().get_ext_filename(ext_name)
+ except Exception as e:
+ print(CUDA_ERROR_MSG.format(e))
+ self.extensions = []
+ return "_C.so"
+
+ cmdclass = {
+ "build_ext": (
+ BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
+ if BUILD_ALLOW_ERRORS
+ else BuildExtension.with_options(no_python_abi_suffix=True)
+ )
+ }
+except Exception as e:
+ cmdclass = {}
+ if BUILD_ALLOW_ERRORS:
+ print(CUDA_ERROR_MSG.format(e))
+ else:
+ raise e
+
+
+# Setup configuration
+setup(
+ name=NAME,
+ version=VERSION,
+ description=DESCRIPTION,
+ long_description=LONG_DESCRIPTION,
+ long_description_content_type="text/markdown",
+ url=URL,
+ author=AUTHOR,
+ author_email=AUTHOR_EMAIL,
+ license=LICENSE,
+ packages=find_packages(exclude="notebooks"),
+ include_package_data=True,
+ install_requires=REQUIRED_PACKAGES,
+ extras_require=EXTRA_PACKAGES,
+ python_requires=">=3.10.0",
+ ext_modules=get_extensions(),
+ cmdclass=cmdclass,
+)
diff --git a/phantom/submodules/sam2/tools/README.md b/phantom/submodules/sam2/tools/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1dd0e8a754f4bf27ee321084076f3ebdb2285450
--- /dev/null
+++ b/phantom/submodules/sam2/tools/README.md
@@ -0,0 +1,36 @@
+## SAM 2 toolkits
+
+This directory provides toolkits for additional SAM 2 use cases.
+
+### Semi-supervised VOS inference
+
+The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset.
+
+After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`.
+```bash
+python ./tools/vos_inference.py \
+ --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
+ --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
+ --base_video_dir /path-to-davis-2017/JPEGImages/480p \
+ --input_mask_dir /path-to-davis-2017/Annotations/480p \
+ --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \
+ --output_mask_dir ./outputs/davis_2017_pred_pngs
+```
+(replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset)
+
+To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag.
+```bash
+python ./tools/vos_inference.py \
+ --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
+ --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
+ --base_video_dir /path-to-sav-val/JPEGImages_24fps \
+ --input_mask_dir /path-to-sav-val/Annotations_6fps \
+ --video_list_file /path-to-sav-val/sav_val.txt \
+ --per_obj_png_file \
+ --output_mask_dir ./outputs/sav_val_pred_pngs
+```
+(replace `/path-to-sav-val` with the path to SA-V val)
+
+Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above.
+
+Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**.
diff --git a/phantom/submodules/sam2/tools/vos_inference.py b/phantom/submodules/sam2/tools/vos_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3e8c6740541f342cfbbe0fa8ad80e47caf4ac9
--- /dev/null
+++ b/phantom/submodules/sam2/tools/vos_inference.py
@@ -0,0 +1,507 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+from collections import defaultdict
+
+import numpy as np
+import torch
+from PIL import Image
+from sam2.build_sam import build_sam2_video_predictor
+
+
+# the PNG palette for DAVIS 2017 dataset
+DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"
+
+
+def load_ann_png(path):
+ """Load a PNG file as a mask and its palette."""
+ mask = Image.open(path)
+ palette = mask.getpalette()
+ mask = np.array(mask).astype(np.uint8)
+ return mask, palette
+
+
+def save_ann_png(path, mask, palette):
+ """Save a mask as a PNG file with the given palette."""
+ assert mask.dtype == np.uint8
+ assert mask.ndim == 2
+ output_mask = Image.fromarray(mask)
+ output_mask.putpalette(palette)
+ output_mask.save(path)
+
+
+def get_per_obj_mask(mask):
+ """Split a mask into per-object masks."""
+ object_ids = np.unique(mask)
+ object_ids = object_ids[object_ids > 0].tolist()
+ per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
+ return per_obj_mask
+
+
+def put_per_obj_mask(per_obj_mask, height, width):
+ """Combine per-object masks into a single mask."""
+ mask = np.zeros((height, width), dtype=np.uint8)
+ object_ids = sorted(per_obj_mask)[::-1]
+ for object_id in object_ids:
+ object_mask = per_obj_mask[object_id]
+ object_mask = object_mask.reshape(height, width)
+ mask[object_mask] = object_id
+ return mask
+
+
+def load_masks_from_dir(
+ input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
+):
+ """Load masks from a directory as a dict of per-object masks."""
+ if not per_obj_png_file:
+ input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
+ if allow_missing and not os.path.exists(input_mask_path):
+ return {}, None
+ input_mask, input_palette = load_ann_png(input_mask_path)
+ per_obj_input_mask = get_per_obj_mask(input_mask)
+ else:
+ per_obj_input_mask = {}
+ input_palette = None
+ # each object is a directory in "{object_id:%03d}" format
+ for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
+ object_id = int(object_name)
+ input_mask_path = os.path.join(
+ input_mask_dir, video_name, object_name, f"{frame_name}.png"
+ )
+ if allow_missing and not os.path.exists(input_mask_path):
+ continue
+ input_mask, input_palette = load_ann_png(input_mask_path)
+ per_obj_input_mask[object_id] = input_mask > 0
+
+ return per_obj_input_mask, input_palette
+
+
+def save_masks_to_dir(
+ output_mask_dir,
+ video_name,
+ frame_name,
+ per_obj_output_mask,
+ height,
+ width,
+ per_obj_png_file,
+ output_palette,
+):
+ """Save masks to a directory as PNG files."""
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
+ if not per_obj_png_file:
+ output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
+ output_mask_path = os.path.join(
+ output_mask_dir, video_name, f"{frame_name}.png"
+ )
+ save_ann_png(output_mask_path, output_mask, output_palette)
+ else:
+ for object_id, object_mask in per_obj_output_mask.items():
+ object_name = f"{object_id:03d}"
+ os.makedirs(
+ os.path.join(output_mask_dir, video_name, object_name),
+ exist_ok=True,
+ )
+ output_mask = object_mask.reshape(height, width).astype(np.uint8)
+ output_mask_path = os.path.join(
+ output_mask_dir, video_name, object_name, f"{frame_name}.png"
+ )
+ save_ann_png(output_mask_path, output_mask, output_palette)
+
+
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def vos_inference(
+ predictor,
+ base_video_dir,
+ input_mask_dir,
+ output_mask_dir,
+ video_name,
+ score_thresh=0.0,
+ use_all_masks=False,
+ per_obj_png_file=False,
+):
+ """Run VOS inference on a single video with the given predictor."""
+ # load the video frames and initialize the inference state on this video
+ video_dir = os.path.join(base_video_dir, video_name)
+ frame_names = [
+ os.path.splitext(p)[0]
+ for p in os.listdir(video_dir)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+ ]
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+ inference_state = predictor.init_state(
+ video_path=video_dir, async_loading_frames=False
+ )
+ height = inference_state["video_height"]
+ width = inference_state["video_width"]
+ input_palette = None
+
+ # fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
+ if not use_all_masks:
+ # use only the first video's ground-truth mask as the input mask
+ input_frame_inds = [0]
+ else:
+ # use all mask files available in the input_mask_dir as the input masks
+ if not per_obj_png_file:
+ input_frame_inds = [
+ idx
+ for idx, name in enumerate(frame_names)
+ if os.path.exists(
+ os.path.join(input_mask_dir, video_name, f"{name}.png")
+ )
+ ]
+ else:
+ input_frame_inds = [
+ idx
+ for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
+ for idx, name in enumerate(frame_names)
+ if os.path.exists(
+ os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
+ )
+ ]
+ # check and make sure we got at least one input frame
+ if len(input_frame_inds) == 0:
+ raise RuntimeError(
+ f"In {video_name=}, got no input masks in {input_mask_dir=}. "
+ "Please make sure the input masks are available in the correct format."
+ )
+ input_frame_inds = sorted(set(input_frame_inds))
+
+ # add those input masks to SAM 2 inference state before propagation
+ object_ids_set = None
+ for input_frame_idx in input_frame_inds:
+ try:
+ per_obj_input_mask, input_palette = load_masks_from_dir(
+ input_mask_dir=input_mask_dir,
+ video_name=video_name,
+ frame_name=frame_names[input_frame_idx],
+ per_obj_png_file=per_obj_png_file,
+ )
+ except FileNotFoundError as e:
+ raise RuntimeError(
+ f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
+ "Please add the `--track_object_appearing_later_in_video` flag "
+ "for VOS datasets that don't have all objects to track appearing "
+ "in the first frame (such as LVOS or YouTube-VOS)."
+ ) from e
+ # get the list of object ids to track from the first input frame
+ if object_ids_set is None:
+ object_ids_set = set(per_obj_input_mask)
+ for object_id, object_mask in per_obj_input_mask.items():
+ # check and make sure no new object ids appear only in later frames
+ if object_id not in object_ids_set:
+ raise RuntimeError(
+ f"In {video_name=}, got a new {object_id=} appearing only in a "
+ f"later {input_frame_idx=} (but not appearing in the first frame). "
+ "Please add the `--track_object_appearing_later_in_video` flag "
+ "for VOS datasets that don't have all objects to track appearing "
+ "in the first frame (such as LVOS or YouTube-VOS)."
+ )
+ predictor.add_new_mask(
+ inference_state=inference_state,
+ frame_idx=input_frame_idx,
+ obj_id=object_id,
+ mask=object_mask,
+ )
+
+ # check and make sure we have at least one object to track
+ if object_ids_set is None or len(object_ids_set) == 0:
+ raise RuntimeError(
+ f"In {video_name=}, got no object ids on {input_frame_inds=}. "
+ "Please add the `--track_object_appearing_later_in_video` flag "
+ "for VOS datasets that don't have all objects to track appearing "
+ "in the first frame (such as LVOS or YouTube-VOS)."
+ )
+ # run propagation throughout the video and collect the results in a dict
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
+ output_palette = input_palette or DAVIS_PALETTE
+ video_segments = {} # video_segments contains the per-frame segmentation results
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+ inference_state
+ ):
+ per_obj_output_mask = {
+ out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
+ for i, out_obj_id in enumerate(out_obj_ids)
+ }
+ video_segments[out_frame_idx] = per_obj_output_mask
+
+ # write the output masks as palette PNG files to output_mask_dir
+ for out_frame_idx, per_obj_output_mask in video_segments.items():
+ save_masks_to_dir(
+ output_mask_dir=output_mask_dir,
+ video_name=video_name,
+ frame_name=frame_names[out_frame_idx],
+ per_obj_output_mask=per_obj_output_mask,
+ height=height,
+ width=width,
+ per_obj_png_file=per_obj_png_file,
+ output_palette=output_palette,
+ )
+
+
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def vos_separate_inference_per_object(
+ predictor,
+ base_video_dir,
+ input_mask_dir,
+ output_mask_dir,
+ video_name,
+ score_thresh=0.0,
+ use_all_masks=False,
+ per_obj_png_file=False,
+):
+ """
+ Run VOS inference on a single video with the given predictor.
+
+ Unlike `vos_inference`, this function run inference separately for each object
+ in a video, which could be applied to datasets like LVOS or YouTube-VOS that
+ don't have all objects to track appearing in the first frame (i.e. some objects
+ might appear only later in the video).
+ """
+ # load the video frames and initialize the inference state on this video
+ video_dir = os.path.join(base_video_dir, video_name)
+ frame_names = [
+ os.path.splitext(p)[0]
+ for p in os.listdir(video_dir)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+ ]
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+ inference_state = predictor.init_state(
+ video_path=video_dir, async_loading_frames=False
+ )
+ height = inference_state["video_height"]
+ width = inference_state["video_width"]
+ input_palette = None
+
+ # collect all the object ids and their input masks
+ inputs_per_object = defaultdict(dict)
+ for idx, name in enumerate(frame_names):
+ if per_obj_png_file or os.path.exists(
+ os.path.join(input_mask_dir, video_name, f"{name}.png")
+ ):
+ per_obj_input_mask, input_palette = load_masks_from_dir(
+ input_mask_dir=input_mask_dir,
+ video_name=video_name,
+ frame_name=frame_names[idx],
+ per_obj_png_file=per_obj_png_file,
+ allow_missing=True,
+ )
+ for object_id, object_mask in per_obj_input_mask.items():
+ # skip empty masks
+ if not np.any(object_mask):
+ continue
+ # if `use_all_masks=False`, we only use the first mask for each object
+ if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
+ continue
+ print(f"adding mask from frame {idx} as input for {object_id=}")
+ inputs_per_object[object_id][idx] = object_mask
+
+ # run inference separately for each object in the video
+ object_ids = sorted(inputs_per_object)
+ output_scores_per_object = defaultdict(dict)
+ for object_id in object_ids:
+ # add those input masks to SAM 2 inference state before propagation
+ input_frame_inds = sorted(inputs_per_object[object_id])
+ predictor.reset_state(inference_state)
+ for input_frame_idx in input_frame_inds:
+ predictor.add_new_mask(
+ inference_state=inference_state,
+ frame_idx=input_frame_idx,
+ obj_id=object_id,
+ mask=inputs_per_object[object_id][input_frame_idx],
+ )
+
+ # run propagation throughout the video and collect the results in a dict
+ for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
+ inference_state,
+ start_frame_idx=min(input_frame_inds),
+ reverse=False,
+ ):
+ obj_scores = out_mask_logits.cpu().numpy()
+ output_scores_per_object[object_id][out_frame_idx] = obj_scores
+
+ # post-processing: consolidate the per-object scores into per-frame masks
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
+ output_palette = input_palette or DAVIS_PALETTE
+ video_segments = {} # video_segments contains the per-frame segmentation results
+ for frame_idx in range(len(frame_names)):
+ scores = torch.full(
+ size=(len(object_ids), 1, height, width),
+ fill_value=-1024.0,
+ dtype=torch.float32,
+ )
+ for i, object_id in enumerate(object_ids):
+ if frame_idx in output_scores_per_object[object_id]:
+ scores[i] = torch.from_numpy(
+ output_scores_per_object[object_id][frame_idx]
+ )
+
+ if not per_obj_png_file:
+ scores = predictor._apply_non_overlapping_constraints(scores)
+ per_obj_output_mask = {
+ object_id: (scores[i] > score_thresh).cpu().numpy()
+ for i, object_id in enumerate(object_ids)
+ }
+ video_segments[frame_idx] = per_obj_output_mask
+
+ # write the output masks as palette PNG files to output_mask_dir
+ for frame_idx, per_obj_output_mask in video_segments.items():
+ save_masks_to_dir(
+ output_mask_dir=output_mask_dir,
+ video_name=video_name,
+ frame_name=frame_names[frame_idx],
+ per_obj_output_mask=per_obj_output_mask,
+ height=height,
+ width=width,
+ per_obj_png_file=per_obj_png_file,
+ output_palette=output_palette,
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--sam2_cfg",
+ type=str,
+ default="configs/sam2.1/sam2.1_hiera_b+.yaml",
+ help="SAM 2 model configuration file",
+ )
+ parser.add_argument(
+ "--sam2_checkpoint",
+ type=str,
+ default="./checkpoints/sam2.1_hiera_base_plus.pt",
+ help="path to the SAM 2 model checkpoint",
+ )
+ parser.add_argument(
+ "--base_video_dir",
+ type=str,
+ required=True,
+ help="directory containing videos (as JPEG files) to run VOS prediction on",
+ )
+ parser.add_argument(
+ "--input_mask_dir",
+ type=str,
+ required=True,
+ help="directory containing input masks (as PNG files) of each video",
+ )
+ parser.add_argument(
+ "--video_list_file",
+ type=str,
+ default=None,
+ help="text file containing the list of video names to run VOS prediction on",
+ )
+ parser.add_argument(
+ "--output_mask_dir",
+ type=str,
+ required=True,
+ help="directory to save the output masks (as PNG files)",
+ )
+ parser.add_argument(
+ "--score_thresh",
+ type=float,
+ default=0.0,
+ help="threshold for the output mask logits (default: 0.0)",
+ )
+ parser.add_argument(
+ "--use_all_masks",
+ action="store_true",
+ help="whether to use all available PNG files in input_mask_dir "
+ "(default without this flag: just the first PNG file as input to the SAM 2 model; "
+ "usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
+ )
+ parser.add_argument(
+ "--per_obj_png_file",
+ action="store_true",
+ help="whether use separate per-object PNG files for input and output masks "
+ "(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
+ "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
+ )
+ parser.add_argument(
+ "--apply_postprocessing",
+ action="store_true",
+ help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
+ "(we don't apply such post-processing in the SAM 2 model evaluation)",
+ )
+ parser.add_argument(
+ "--track_object_appearing_later_in_video",
+ action="store_true",
+ help="whether to track objects that appear later in the video (i.e. not on the first frame; "
+ "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
+ )
+ parser.add_argument(
+ "--use_vos_optimized_video_predictor",
+ action="store_true",
+ help="whether to use vos optimized video predictor with all modules compiled",
+ )
+ args = parser.parse_args()
+
+ # if we use per-object PNG files, they could possibly overlap in inputs and outputs
+ hydra_overrides_extra = [
+ "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
+ ]
+ predictor = build_sam2_video_predictor(
+ config_file=args.sam2_cfg,
+ ckpt_path=args.sam2_checkpoint,
+ apply_postprocessing=args.apply_postprocessing,
+ hydra_overrides_extra=hydra_overrides_extra,
+ vos_optimized=args.use_vos_optimized_video_predictor,
+ )
+
+ if args.use_all_masks:
+ print("using all available masks in input_mask_dir as input to the SAM 2 model")
+ else:
+ print(
+ "using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
+ )
+ # if a video list file is provided, read the video names from the file
+ # (otherwise, we use all subdirectories in base_video_dir)
+ if args.video_list_file is not None:
+ with open(args.video_list_file, "r") as f:
+ video_names = [v.strip() for v in f.readlines()]
+ else:
+ video_names = [
+ p
+ for p in os.listdir(args.base_video_dir)
+ if os.path.isdir(os.path.join(args.base_video_dir, p))
+ ]
+ print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")
+
+ for n_video, video_name in enumerate(video_names):
+ print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
+ if not args.track_object_appearing_later_in_video:
+ vos_inference(
+ predictor=predictor,
+ base_video_dir=args.base_video_dir,
+ input_mask_dir=args.input_mask_dir,
+ output_mask_dir=args.output_mask_dir,
+ video_name=video_name,
+ score_thresh=args.score_thresh,
+ use_all_masks=args.use_all_masks,
+ per_obj_png_file=args.per_obj_png_file,
+ )
+ else:
+ vos_separate_inference_per_object(
+ predictor=predictor,
+ base_video_dir=args.base_video_dir,
+ input_mask_dir=args.input_mask_dir,
+ output_mask_dir=args.output_mask_dir,
+ video_name=video_name,
+ score_thresh=args.score_thresh,
+ use_all_masks=args.use_all_masks,
+ per_obj_png_file=args.per_obj_png_file,
+ )
+
+ print(
+ f"completed VOS prediction on {len(video_names)} videos -- "
+ f"output masks saved to {args.output_mask_dir}"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/phantom/submodules/sam2/training/README.md b/phantom/submodules/sam2/training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0c829d49d051d8f72e7bef959e33e6f0329c94d
--- /dev/null
+++ b/phantom/submodules/sam2/training/README.md
@@ -0,0 +1,116 @@
+# Training Code for SAM 2
+
+This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos.
+The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both).
+
+## Structure
+
+The training code is organized into the following subfolders:
+
+* `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms.
+* `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling).
+* `utils`: This folder contains training utils such as loggers and distributed training utils.
+* `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training.
+* `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training.
+* `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers.
+* `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc..) and implements the main train/eval loop.
+* `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h`
+
+## Getting Started
+
+To get started with the training code, we provide a simple example to fine-tune our checkpoints on [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets.
+
+#### Requirements:
+- We assume training on A100 GPUs with **80 GB** of memory.
+- Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download).
+
+#### Steps to fine-tune on MOSE:
+- Install the packages required for training by running `pip install -e ".[dev]"`.
+- Set the paths for MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`.
+ ```yaml
+ dataset:
+ # PATHS to Dataset
+ img_folder: null # PATH to MOSE JPEGImages folder
+ gt_folder: null # PATH to MOSE Annotations folder
+ file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
+ ```
+- To fine-tune the base model on MOSE using 8 GPUs, run
+
+ ```python
+ python training/train.py \
+ -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
+ --use-cluster 0 \
+ --num-gpus 8
+ ```
+
+ We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html), for example, you can train on 2 nodes by running
+
+ ```python
+ python training/train.py \
+ -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
+ --use-cluster 1 \
+ --num-gpus 8 \
+ --num-nodes 2
+ --partition $PARTITION \
+ --qos $QOS \
+ --account $ACCOUNT
+ ```
+ where partition, qos, and account are optional and depend on your SLURM configuration.
+ By default, the checkpoint and logs will be saved under `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows:
+
+ ```yaml
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
+ ```
+ The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split]( ../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base plus model is 79.4.
+
+
+ After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) similar to SAM 2 released checkpoints (as illustrated [here](../README.md#image-prediction)).
+## Training on images and videos
+The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to setup the datasets in your config to train on a mix of image and video datasets:
+
+```yaml
+data:
+ train:
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
+ phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases
+ batch_sizes: # List of batch sizes corresponding to each dataset
+ - ${bs1} # Batch size of dataset 1
+ - ${bs2} # Batch size of dataset 2
+ datasets:
+ # SA1B as an example of an image dataset
+ - _target_: training.dataset.vos_dataset.VOSDataset
+ training: true
+ video_dataset:
+ _target_: training.dataset.vos_raw_dataset.SA1BRawDataset
+ img_folder: ${path_to_img_folder}
+ gt_folder: ${path_to_gt_folder}
+ file_list_txt: ${path_to_train_filelist} # Optional
+ sampler:
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
+ num_frames: 1
+ max_num_objects: ${max_num_objects_per_image}
+ transforms: ${image_transforms}
+ # SA-V as an example of a video dataset
+ - _target_: training.dataset.vos_dataset.VOSDataset
+ training: true
+ video_dataset:
+ _target_: training.dataset.vos_raw_dataset.JSONRawDataset
+ img_folder: ${path_to_img_folder}
+ gt_folder: ${path_to_gt_folder}
+ file_list_txt: ${path_to_train_filelist} # Optional
+ ann_every: 4
+ sampler:
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
+ num_frames: 8 # Number of frames per video
+ max_num_objects: ${max_num_objects_per_video}
+ reverse_time_prob: ${reverse_time_prob} # probability to reverse video
+ transforms: ${video_transforms}
+ shuffle: True
+ num_workers: ${num_train_workers}
+ pin_memory: True
+ drop_last: True
+ collate_fn:
+ _target_: training.utils.data_utils.collate_fn
+ _partial_: true
+ dict_key: all
+```
diff --git a/phantom/submodules/sam2/training/__init__.py b/phantom/submodules/sam2/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/training/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/training/assets/MOSE_sample_train_list.txt b/phantom/submodules/sam2/training/assets/MOSE_sample_train_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..28b22e3170f63de0fba3c77ef999f958cd6c48ff
--- /dev/null
+++ b/phantom/submodules/sam2/training/assets/MOSE_sample_train_list.txt
@@ -0,0 +1,1246 @@
+28191f94
+662487fe
+80906bf9
+7e704f2e
+efa25913
+b6f03bd9
+6834d249
+5a723c30
+07779415
+4ce088c6
+199995b5
+54273925
+4fa342f5
+110da3cf
+65856fa0
+46705bb3
+d869a3cf
+555aa049
+8f01fb2c
+37b07a28
+5e80b3dd
+ba0e4dd4
+6f5144b6
+acec8407
+93723f88
+c7c7528c
+97f58761
+e71f9faa
+e64c13dc
+8830d59d
+0e4aeed9
+63437cf3
+95215aa1
+255f86ef
+dc54aab2
+327cd258
+198021ad
+c690220c
+d25ff89d
+7875b874
+4fa6d325
+9fc933f6
+4d8baafe
+55ae6921
+6a3bc149
+89f8163f
+2d65d2ac
+dba172b1
+a14de179
+4017d1b3
+52ddf44c
+3ba93641
+34a5f964
+da7dee28
+872b76de
+1dc12eca
+265a69f4
+86a2b59f
+51e5ca25
+ddf80bcd
+6786602e
+4fa28c89
+f56942e9
+2184bb93
+d883e976
+bfe1469e
+bc4e7b11
+1c80acb0
+2b0e34d3
+56b9ce41
+15f0b0cd
+cc5d0dd1
+1b7eada8
+7286b176
+0ab42ab1
+adb82dc9
+c060b1e6
+3da63bd5
+5488796e
+d7066e20
+aab5ed11
+17f66311
+24df9789
+208fa934
+7ce2c865
+debe4249
+4c56bbea
+149dbae2
+beb693c9
+49eb0315
+e7ad4717
+4e016d5a
+95e24093
+07b5d86c
+80701b6c
+337dfa1e
+b624a46e
+3f849de8
+5db21df2
+47891b4c
+a966d7fd
+013103f6
+da5e4bc5
+ba9ea03d
+526195de
+57f3a53e
+b3aff7f8
+26048547
+bb7ee856
+aef0d049
+e35a8262
+57ad022e
+f45d3823
+e5e9eb29
+39cc637e
+a4fc4f17
+dd5a4739
+bbe97d18
+33602f6b
+9061dac9
+23454d80
+a20baeec
+794f01d4
+02de2f2a
+055fca57
+a69df343
+e307510e
+d07ad1be
+1fc5e086
+db6533a5
+fe9706b7
+87e32230
+8ba58e4c
+561f6380
+2ab9ba0f
+86571569
+756cc6c9
+aa185af5
+c6d7f94b
+7f54c579
+71f4b40e
+4190c83a
+fef0aba4
+2f7c71bb
+e4b6f2ef
+76adaeea
+11cdeb64
+733f2a02
+e50dbddb
+f643141f
+d2e75e95
+84559bc3
+7ade3068
+e69db797
+0b787263
+57895315
+d7969c29
+62529cd4
+203733e7
+48fd97a6
+723fd024
+849f0efb
+aafea009
+dd4eb8f1
+d18554ae
+f3c0f0cf
+90fe55b9
+b0ffaf3b
+e79ecd47
+d670ce7b
+56a5643a
+90ff1d09
+1fb378d9
+57014c7d
+994ed763
+5bc7ea74
+e99bd793
+cbb66185
+5f3fcff6
+05ed1023
+85efa9e3
+652929ce
+905d8740
+a6fcde01
+0fdf67f7
+a5cf4c8d
+e1c48bdd
+782551f7
+6acd353f
+c30641cf
+81d12756
+51befc31
+9d5ab5ca
+d262b7e4
+2cd705a9
+f7360199
+d3f3bf9d
+028f6f64
+94767cb4
+3a739934
+72433603
+ec66879d
+6149becc
+5845c157
+c5082b3c
+f89b54d0
+f3ada126
+409dcb8a
+4411fdee
+eb93ed20
+9cb1ba0e
+b8e1ec26
+7edd8b4f
+5e9412c0
+2744f35a
+dafeb75e
+f3f072f2
+6f1df574
+5a064706
+89c76ac4
+a6adef89
+76303516
+dbd67417
+a53ef3fa
+10552818
+ac7deb19
+2d403c59
+55c157f1
+214aeac3
+a9f5e251
+d7807996
+d1dba33b
+1367e367
+44476e77
+0644075b
+eda37457
+f2de4198
+9a4ce701
+46e00caf
+2ae75f99
+cd49fb99
+4e4483e7
+a0669957
+a6f0d882
+9ce1d54a
+1fc2314b
+21f363b3
+32ecef67
+70bcaf68
+115348f9
+60827ada
+a218e951
+6d30d5ac
+6da17988
+f22c39ce
+5825f0e0
+f415f9ad
+0d4feda2
+832fc243
+414ca58b
+a92390a0
+ddd383cc
+43dc67f7
+962ae0e2
+6dd74e7b
+2bcd6c3b
+b394847f
+637fd121
+d46e771b
+f6bfc699
+63f138de
+932ad0a6
+2080824a
+52fa9174
+843d3bf7
+f3431885
+5c20c48a
+134a2ab0
+2ea465de
+f6786ab5
+2bf49664
+a49ce97b
+6a50e93a
+a7c21e95
+616ad8ec
+0a8d7b41
+b0c90527
+2d893fb7
+19310598
+7744dc51
+4539b907
+9d299f60
+e495537a
+0b02886a
+f4c4a2ca
+e957b2b5
+e6f3bf07
+258944c8
+54364322
+ebb77f95
+0af03282
+cbdbc6c3
+494ecef0
+ee91f783
+9698f06e
+11e16068
+b942ce0a
+423a50e6
+fb16e746
+9c88ae45
+8620c024
+d3af3c85
+780a25de
+e569a15f
+c4f9f19e
+1106f3a7
+d37e29a7
+e53611da
+fdb2e432
+18ad3117
+6fcd426d
+3bfa8379
+3b19c5c3
+ff1142df
+cd182615
+b60ea255
+b3f5d019
+6dc5e55d
+103166c7
+37af9ac1
+ad1881d1
+731149b3
+90e3338a
+6aa0b6f2
+a25316a3
+dc8679e0
+571fb490
+80afed16
+983a551b
+a58578e5
+2bc0bba4
+1143b3fe
+fdd8dd49
+7fe2bf77
+890ef032
+8466eeb2
+c791ddbb
+631b82bd
+78bf9b51
+a99df45f
+2bdb692f
+e89b1501
+4e6aa1e8
+e5665030
+fe21fd5c
+635577d5
+4414cd3a
+03c99e83
+ff041cd1
+c33adbc2
+a988ec74
+576031e0
+03c21af7
+79b25f4b
+bbc485d6
+d36d5a0d
+efdab888
+b20e6781
+81fdc526
+e1c26a53
+7c6d3504
+52a04667
+f22e34d4
+bb936ead
+13f0606c
+d2abc61e
+af509e8f
+bea1c144
+e15e4de8
+e727099f
+b30744df
+ffb6a2e4
+0d31d3a6
+a23048fe
+7d452630
+6c736334
+046ed4f4
+94f4c2aa
+c290cfd3
+f7203226
+2fdae3c5
+7c78e351
+02b72b8d
+2d22d3be
+ba28d02e
+197f6587
+43199a98
+b563b04f
+9293b755
+9cef7489
+d156b96f
+15e9161e
+6d094cd5
+0d876a65
+c818d30a
+8094b12b
+a4a8e24b
+14655f54
+11c14893
+8a48f62a
+7f3d9c22
+d952481c
+03e0f9b8
+28980657
+6a0b5563
+5879983c
+37549a79
+4a7162bd
+7a6aa1ef
+0dc1b78c
+f6dba17b
+1dba51af
+b2f4d608
+e2e6f421
+464066da
+5d24e4ea
+1e75004d
+a02ed92c
+673adbcc
+c2a0c0fd
+85addee5
+54b8f502
+f5d2d8d3
+a19507e1
+803e1756
+0d1fe009
+5968c2d8
+b926e1ad
+a9162e14
+ae470d2b
+bd731802
+68c879f2
+21fe05d9
+c1ed21d0
+831498e4
+cc45a7f2
+cb170015
+59750be4
+30d1cb6b
+03e5f069
+106d33db
+3f003746
+3e5ad020
+8bc5a91c
+64b89eb5
+bfd28682
+f8687b9a
+7bbf38ee
+d6d92b30
+ceaa6c65
+677c8ed7
+dc33acf8
+cfd1de31
+e5be4781
+85585220
+5d2316f6
+dd3f4a07
+34535f5f
+3ae0bc5d
+f521e3c5
+74c2284f
+12a42fd9
+61403519
+88cd32f3
+662a1846
+825a1944
+cf376cf1
+8465d99c
+61a2e246
+62d44645
+103b3ca8
+c7e745ed
+4ed71139
+230c2edf
+529c6889
+9e509c0d
+54b9dea2
+a8934c0d
+29cffe2f
+48017512
+c9f7f69d
+ce691ee6
+21c89360
+3b97c07b
+ebd82d35
+2895bb8b
+7043c5c1
+85d694d7
+88fd7507
+18d8931e
+aa718745
+89b671bb
+0d8d30ae
+26163977
+a6121689
+1589579d
+159789c4
+f5ca8271
+fcc16740
+3158be0b
+860fc1f7
+3f54a330
+82f24ce7
+069f6a2a
+2fa9c523
+c9f1d87f
+efe9cbca
+8f969ea5
+4f5db794
+62c501f8
+2d3b0320
+c99637f0
+0f3b1fcb
+6e4ee861
+e0d9aff0
+230ddb91
+e14d1f96
+c83aa6a1
+eabdf66a
+6783a303
+81659eb2
+ce954bd7
+9a48c0c9
+0ab807b4
+f0617f71
+fe86f2f8
+61d80e22
+e4b6d2a0
+ac093040
+0e05fabe
+d0b507c3
+3d828137
+c4fa0bab
+f7783321
+ec27366a
+404e4c58
+073baf48
+0f685e01
+b0e98fdd
+b4891f7f
+a46b7b77
+ee059f99
+3c87888e
+8d23ddcc
+2d8d7d35
+5680be79
+fc79c03e
+20660b72
+53f67585
+90956534
+7e709e2d
+dae93f5c
+54b9dbba
+cc41ba05
+1e207fe0
+a9c6abf2
+35e0ca09
+e3dcd186
+1b8bb699
+92162474
+cdad6812
+50b91533
+570215ac
+6042d64a
+b6e2c041
+08746283
+7a056996
+b8651773
+adf443e1
+6a6e0e3b
+886ed981
+c1d57fea
+43030c4c
+7ebfbf57
+0770ad03
+e85301d5
+31ac3d98
+acaef45e
+8f415dd1
+fe2dc281
+2c0b9d99
+8e24501e
+911ec4ad
+8036b58e
+c3b350b9
+b6cadd11
+a3a80cf7
+88ab50cd
+59c755a8
+1339321a
+91b2f707
+97b0811e
+1da33959
+31b09833
+c1a40349
+708098a9
+1f220f98
+999e07cb
+0b5e5d29
+94c63453
+b826d642
+a598602d
+4c83eab8
+2efd5e50
+6ec5da3a
+9fcd95eb
+9a2c6b5b
+c205a718
+e638e950
+cb43141c
+494dd91d
+c4957274
+4975a81d
+a1f4c54d
+51e6fafa
+514490e5
+b0d09e6a
+c6726eb8
+06772c9a
+5a65ffd7
+3657c62b
+03012cfd
+529df209
+f1c38e66
+ab417352
+118a067e
+8957514f
+22e8b380
+3b1a4616
+a4457543
+57c9f6e0
+e362c16b
+0f809e41
+857e375e
+9cff25e3
+d754fb65
+6ad44b86
+051052d8
+a4564b94
+f68507d0
+80a7cf7b
+ad8cd1e0
+60b19cd3
+274fe944
+f06632aa
+628a337b
+92c96c05
+87fc565c
+6f6e6c37
+228a0234
+6487110a
+aa911a8e
+40c47fa3
+9606508b
+6ba9e61f
+c8c1d5a9
+cf01df5b
+9421b9ad
+006e6b64
+1c28e081
+06273084
+8925e11b
+b46c822b
+00501424
+cfd946b2
+2e92a7dc
+1c5f5bb6
+1d29944c
+8248698e
+19247506
+1eac1aff
+ee9caa47
+4a41cbf8
+d97c9309
+4ca87c14
+9707f1e3
+8bb9a221
+6605e67d
+95cf72d7
+1c6fb814
+033130b2
+4344808d
+5f14e5d2
+a810399b
+e325a6d4
+7014ddf4
+725d4bfb
+790285e8
+1a6a731f
+fbfb6e30
+0d4d88f6
+80ce18a4
+572495b7
+4b44dc50
+95dce33c
+4a6fb202
+3142014e
+a3c56751
+96b2a414
+c4aa176c
+fd1e394f
+93f0f509
+f494e9fa
+bfa42a75
+db5319c7
+aa92e070
+81220a93
+e4a72496
+fc467bf1
+5397b01d
+1dc0c9a0
+f6f8b4a6
+53dc7db4
+8ef303eb
+62ca45c9
+e9d3465e
+3784e3f6
+8c934e67
+5ba84e3f
+30e41f1e
+61cf0ec8
+e93e8f01
+fc6086dd
+a95f0aea
+33a04ef2
+6f295adb
+d2aa8c66
+724cc810
+d8623d26
+8d0d641a
+4bda7a76
+38030c69
+56199c41
+d2f4b9e2
+a7b8ac96
+64044df1
+fd1078cc
+0165667b
+16e1cca7
+915f0d9a
+eeaaa67e
+378430d5
+a84c60e6
+b4ae36cc
+2a3a0571
+13e6df75
+aa348c45
+59d7a11d
+68954daf
+d6f883c6
+f28b429a
+32dc49d4
+ccf14ee0
+7d512591
+9bdabdb2
+ed878d94
+54eda06d
+132561ee
+3c4b6736
+0367af42
+531c1c36
+843d8f25
+333bdbdc
+c3c21268
+07b00746
+c7fe0584
+49fc9f2e
+9ed4317a
+d29991b4
+98b0033d
+f0b922bf
+89fe6899
+58264713
+2f49220a
+6ff85ca5
+4b96b2c8
+a42f54f5
+aa425600
+22fdee40
+dde85a9d
+3722f6fe
+e7529cbc
+5ae23f9f
+cc32235b
+730bc486
+b12701b7
+a96b3010
+16130bd3
+2c713560
+f7935d24
+a7eb6616
+0d6e7177
+100edaef
+0442a954
+60f4fa43
+37bf7edf
+76b18413
+ab0646a9
+c575434d
+1e356390
+5416fbb7
+df7cf932
+269872de
+9033b607
+c2e88575
+932542cd
+23e046fb
+3d08dadd
+7999adc5
+ed81c485
+3bd7facd
+1feae28e
+8d72533b
+6a8d35d6
+65308bdc
+7f0b7662
+98290486
+fee3371f
+c463c7e5
+faf7d852
+75c34dc5
+96a6722e
+e5605136
+851bc5d9
+15c41c4b
+6a39e104
+5fbff256
+0e7001dd
+5411113f
+3ea2f7f2
+242b74b1
+87727003
+ec6dd0e9
+980baf58
+9d0b7bf1
+9113c9d4
+5ebef6bd
+a5f70ce7
+b0240233
+06ad78e0
+8745edd0
+d8e8d984
+ac32a655
+38568758
+d48c552d
+0b27d5f7
+c65d0736
+800e3c14
+d37a5857
+bcebc660
+d3ab52cc
+405e3ee7
+e33cddc9
+b0197182
+89fd5681
+9e192417
+8554c402
+aae923b8
+31af515d
+75b26f88
+60471744
+460945aa
+c0fe8e1a
+1731babb
+2e85e35d
+f9c20062
+115da184
+ddfa88c7
+359003f8
+dfa99126
+bf04814f
+f407a414
+e18723c4
+0a7a3629
+c07ab37e
+1251a1c9
+4d09d22a
+5984ed74
+34504f63
+ced51047
+08ff419c
+d942e98c
+2697f864
+3b671a61
+72a2f7e2
+48e7cafe
+6adad2f7
+18840617
+1e44f47e
+36cc4055
+8c494902
+2982de7a
+6a428397
+c4a0ecfb
+231d6945
+fe470104
+f93e1bd0
+bd18bc5a
+7bd70d93
+8f81a0ee
+db78e7a1
+7593caea
+86d5b29b
+5457b298
+0d967fd1
+62372d4c
+68259db3
+f0944ea2
+7b017dbf
+bcb6e338
+03692b14
+f7d36a47
+1ca2531a
+6728528d
+1fc0e6a8
+0ba9c5ad
+a386eaa2
+b0c5459f
+1d64aff3
+b97d4f1a
+b3745d91
+c461003e
+910bf878
+ae42601c
+8d2ddeff
+aaecaa39
+250b5034
+edb11192
+7bfe9b57
+6d533759
+51586b36
+a38d648a
+8fdb48e5
+6075d6b0
+3588ea03
+bc844942
+398d41f5
+660e3b70
+0b99f522
+f169fd1b
+7bfa2ab5
+ab461319
+25153e58
+002b4dce
+a2df1bee
+550a7357
+b604f2dd
+2f477d05
+bdf9eb5a
+857ddc6e
+c8f0fd41
+6df96f15
+e147ab26
+788da8e8
+02221fb0
+d1d95c61
+a3f0cb28
+3a6e6ace
+67c2909a
+220382ab
+eaed776d
+aff08a61
+b99d1bd6
+9d9ae988
+34ccea00
+41dae436
+18513251
+ad57acd1
+67f110fc
+3f09f5c9
+25ef7d43
+12a5d0d7
+3ff48b8b
+26ed56e6
+c047a092
+bb8639e1
+8788747f
+584838d4
+f8e5f837
+657242e8
+cb8eedf4
+74a917f1
+578f71da
+c9b27125
+22e1f53c
+f40145c2
+4795259b
+3f313a2f
+c9012bf6
+22167a50
+6e7f9437
+ef51a724
+356e0fcb
+d3ea999d
+08a5c662
+85aa3b0e
+579fadec
+7bc95dc2
+c097af8e
+f01d8b9f
+80fb79c6
+ea65e6b7
+29ff29f6
+9e1f739d
+b7fb59c9
+e2160f17
+0be33bc1
+e96b9b04
+b1affe79
+c4f4b2e2
+f4c8ffb1
+6a009e50
+a8828854
+2786f841
+a64e724c
+5f54d077
+7040385d
+6e0f0ecc
+f33d3c15
+8108b358
+46a502de
+1e0fb02a
+ddbdfa32
+e7b34ab6
+c9080ed1
+395224b3
+33f9ab47
+c245ecda
+c28d81a9
+37303a3b
+6380dd6f
+2fb5a55b
+83b7c53c
+41c8d0d2
+3aab2d13
+dc7d21fb
+86a88668
+37bb38fe
+ab6413a8
+bbe585b2
+a0ca072a
+9d5940d2
+ddb1d0b1
+a946317a
+988b29a4
+89dc0432
+5df8490d
+5e167efa
+50a86faa
+fe6a535a
+a9f8b8b4
+6e2dce1b
+d0696759
+c09da3b2
+f07dd347
+67408899
+406165ff
+a4a9d03d
+9b5f0f47
+5f3e8022
+1d7a23e0
+25af2eeb
+82a3db34
+c9351029
+6c93d44c
+f088ad1c
+9ee59f51
+b5276b3f
+ca74a924
+781af187
+fa3e0b85
+b898c99e
+1ca51f06
+5a92a0c1
+138c81fe
+d0722d0f
+05a7d84d
+e18f1dea
+799a2d61
+8276e558
+f0ba8748
+ce733e8a
+2f9d0911
+58f24fa4
+66a25278
+3135d31d
+4b9223ee
+bdd5e6b3
+ddbebec1
+8dbebbd9
+3020b38f
+e607450d
+724a5d1c
+91b754c5
+2e85e790
+3a407bd9
+fd137178
+a304029b
+4023fc77
+440d5072
+2eb73c7c
+164a7305
+b33ade7c
+277ad883
+b0f7e75c
+74107936
+83924bdb
+b72beb78
+86c01d64
+f6f441eb
+23b9a3ea
+80b73f1a
+93c6411d
+1e95ef5e
+800b5eac
+9519832a
+ae043406
+b06a902e
+1dbca5cc
+571f88a1
+b1faf52b
+45572497
+8d016cdb
+f92cdae8
+316931f8
+f9884439
+e1b7f212
+e23c6392
+ccfae073
+5aa1efda
+74f0687c
+eaff3301
+b6520a94
+c5398714
+15e7e4d1
+0fc00006
+8cf49218
+3a8ddc0a
+e7e2a0b9
+eec4c008
+8d73085e
+77e246da
+00e92ab4
+f76f6cf9
+19801183
+233406ef
+b80e028c
+342c0b2a
+a2768c47
+99350a74
+adbd400b
+f3978ade
+b87a4f6c
+fa95a6a2
+6dff20c9
+935b5ad8
+dbbbb401
+1b6472c1
+9c0e6331
+04ae7a6b
+4c94e4f3
+90cb46cb
+2831ecf5
+ff77a145
+79af6097
+ba61a719
+abcb7665
+7e87750e
+c4c7bc5d
+3a670b81
+3d9a7023
+82667d52
+a4587f62
+ca619b7f
+7c5462f5
+bda5c60d
+e6e48ac8
+405c6000
+7981f344
+f7375ab3
+bb467ff9
+cfc68a82
+e417a6d8
+1a6177c1
+7b75dace
+b1af350d
+484d48a3
+1f805416
+7416ab4e
+1291276c
+9e85179b
+5a74660c
+7e6d00df
+01e3cec8
+ee2c0688
+f6de8226
+a217538c
+b432c3ef
+49e5ff4e
+035359e5
+8ae8e7ed
+2da12766
+cac39070
+115adda4
+1a2872dc
+fac3378e
+294e7bf8
+a1a4991f
+c062f4d7
+72b2b77d
+158062aa
+9ae447a7
+a7b05677
+fdfd5d56
+eac1a9e6
+a5905593
+59992293
+84298fae
+f708e55f
+093d3d93
+75d26197
+924f5d88
+3184a7ec
+b454fdbc
+2d9101b8
+ae70fb7c
+4385b2c4
+63b37343
+0b4b662c
+2883ae72
+ffcab778
+0f96e2d7
+897066e3
+f23e98ad
+797a7b7e
+2fc476f9
diff --git a/phantom/submodules/sam2/training/assets/MOSE_sample_val_list.txt b/phantom/submodules/sam2/training/assets/MOSE_sample_val_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9721028718245ff5297fdae59d35a7c89cb5f56a
--- /dev/null
+++ b/phantom/submodules/sam2/training/assets/MOSE_sample_val_list.txt
@@ -0,0 +1,200 @@
+32e5d721
+5bad0bab
+267bfd6c
+0a43a414
+56c56ca9
+9a1146b3
+c6ad7aaf
+78a1f4b1
+fc455e73
+072e7b3f
+77ccb57d
+a76ee415
+8cdcfc17
+5d518b42
+376dd830
+0e843fc8
+2af0e766
+2bd4e845
+de2f2a6a
+ade9ee91
+001ca3cb
+fc4c1c67
+8ef55579
+b84ce852
+4cc8528a
+767ffaaa
+112a2ef0
+a338c8aa
+cbd144f5
+5ff72128
+86a949e2
+9f2323ac
+1fab1d1c
+75924351
+ef55817b
+02deca50
+4d979d99
+4d65f873
+28470fa0
+0d1575fe
+06ea172e
+29a6ddc2
+797f1bec
+780e7a99
+b9ed5b44
+02a236b4
+607d8ff5
+af5666b2
+0558d0ed
+a938c6b2
+103df575
+77110e80
+739e5a07
+6763a576
+06ebc138
+ba4b3b09
+b35cc2f3
+4e0597a0
+5949ee84
+5348d547
+323c4236
+b3b51117
+55727ddd
+ab2714f3
+d2878895
+c0734cb3
+94f7c53e
+2a2745e5
+442ffb54
+3592425a
+50ae03b0
+5f150435
+3067f9fa
+9ffb2818
+adeaf5aa
+31caacec
+1cd99b86
+aa22f9d0
+8fa50320
+e6348d2c
+42ff84a5
+8c8b7913
+c96adcbc
+495be321
+db735509
+ee113fc4
+a678cdab
+c409ca4d
+68d2b259
+592b4dee
+4e2b4dc7
+eb4d26e1
+2009a00f
+bec5c89d
+67191f24
+a3e85b4b
+da7080cd
+80d978e9
+36dcb93f
+a41e8c44
+12fdc864
+46d140ea
+657c9dd9
+a86f84ee
+90c1c43d
+33015509
+afc7664d
+23df06e1
+291d4799
+0ab75563
+251bf059
+bcefdcc4
+ce9a2796
+94d3403a
+8f2e04bc
+f9cda066
+9dfa2cc5
+66924c91
+e765a09e
+15654ee1
+48e0bd39
+ee095221
+2463609b
+544d0d1f
+51b8c2e1
+d321dde4
+4cb11a5f
+d7058a0d
+37af282a
+fabae187
+7be91184
+181ec185
+2d16ceeb
+b56be4b1
+6699eff0
+79acac96
+d61c4665
+0c13e1e7
+100f6ecf
+71217dfc
+82df0888
+4c42c747
+c9fdf703
+d2efeb4b
+69ed9d14
+64914fb6
+255bedbc
+4ea934d8
+a034feb2
+e4f4ddae
+e36a3026
+c1489591
+111bb373
+e1d9fb32
+93e22d48
+c1ec4b26
+d9638e69
+60ab04c5
+cfe7773a
+62132822
+2f5fb2a3
+7bdd197d
+033333fd
+130fcdbe
+12e509c2
+67138c33
+6f90cc5f
+4e3020fe
+bbdd8bb7
+b399ccdb
+fecd10d2
+2e0967f7
+f509054f
+792c6ff7
+48e2afc5
+d904c048
+111e0a5c
+b83024e2
+e6a7b79c
+bdc5ccf7
+b8146d00
+9d394f1a
+645b84f9
+95ab2d0f
+e6f8a31d
+b4f876fb
+dc2c570d
+3afd02d7
+5c80c82c
+b1b32ddd
+9f25fc61
+ba538072
+f8916fef
+43c04ad2
+a658e949
+2861dd53
+f6e40aba
+09d305d1
+aac33bff
+8d9d4c08
diff --git a/phantom/submodules/sam2/training/dataset/__init__.py b/phantom/submodules/sam2/training/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/training/dataset/sam2_datasets.py b/phantom/submodules/sam2/training/dataset/sam2_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..6deda056bea555fc07ace455ccc62c606a7b81c9
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/sam2_datasets.py
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+from typing import Callable, Iterable, List, Optional, Sequence
+
+import torch
+
+from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
+
+from torch.utils.data.distributed import DistributedSampler
+
+
+class MixedDataLoader:
+ def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
+ """
+ Args:
+ dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
+ mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
+
+ """
+ assert len(dataloaders) == mixing_prob.shape[0]
+ self.dataloaders = dataloaders
+ self.mixing_prob = mixing_prob
+ # Iterator state
+ self._iter_dls = None
+ self._iter_mixing_prob = None
+ self.random_generator = torch.Generator()
+
+ def __len__(self):
+ return sum([len(d) for d in self.dataloaders])
+
+ def __iter__(self):
+ # Synchronize dataloader seeds
+ self.random_generator.manual_seed(42)
+ self._iter_dls = [iter(loader) for loader in self.dataloaders]
+ self._iter_mixing_prob = self.mixing_prob.clone()
+ return self
+
+ def __next__(self):
+ """
+ Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
+ """
+ if self._iter_dls is None:
+ raise TypeError(f"{type(self).__name__} object is not an iterator")
+
+ while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob.
+ dataset_idx = self._iter_mixing_prob.multinomial(
+ 1, generator=self.random_generator
+ ).item()
+ try:
+ item = next(self._iter_dls[dataset_idx])
+ return item
+ except StopIteration:
+ # No more iterations for this dataset, set it's mixing probability to zero and try again.
+ self._iter_mixing_prob[dataset_idx] = 0
+ except Exception as e:
+ # log and raise any other unexpected error.
+ logging.error(e)
+ raise e
+
+ # Exhausted all iterators
+ raise StopIteration
+
+
+class TorchTrainMixedDataset:
+ def __init__(
+ self,
+ datasets: List[Dataset],
+ batch_sizes: List[int],
+ num_workers: int,
+ shuffle: bool,
+ pin_memory: bool,
+ drop_last: bool,
+ collate_fn: Optional[Callable] = None,
+ worker_init_fn: Optional[Callable] = None,
+ phases_per_epoch: int = 1,
+ dataset_prob: Optional[List[float]] = None,
+ ) -> None:
+ """
+ Args:
+ datasets (List[Dataset]): List of Datasets to be mixed.
+ batch_sizes (List[int]): Batch sizes for each dataset in the list.
+ num_workers (int): Number of workers per dataloader.
+ shuffle (bool): Whether or not to shuffle data.
+ pin_memory (bool): If True, use pinned memory when loading tensors from disk.
+ drop_last (bool): Whether or not to drop the last batch of data.
+ collate_fn (Callable): Function to merge a list of samples into a mini-batch.
+ worker_init_fn (Callable): Function to init each dataloader worker.
+ phases_per_epoch (int): Number of phases per epoch.
+ dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
+ """
+
+ self.datasets = datasets
+ self.batch_sizes = batch_sizes
+ self.num_workers = num_workers
+ self.shuffle = shuffle
+ self.pin_memory = pin_memory
+ self.drop_last = drop_last
+ self.collate_fn = collate_fn
+ self.worker_init_fn = worker_init_fn
+ assert len(self.datasets) > 0
+ for dataset in self.datasets:
+ assert not isinstance(dataset, IterableDataset), "Not supported"
+ # `RepeatFactorWrapper` requires calling set_epoch first to get its length
+ self._set_dataset_epoch(dataset, 0)
+ self.phases_per_epoch = phases_per_epoch
+ self.chunks = [None] * len(datasets)
+ if dataset_prob is None:
+ # If not provided, assign each dataset a probability proportional to its length.
+ dataset_lens = [
+ (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
+ for d, bs in zip(datasets, batch_sizes)
+ ]
+ total_len = sum(dataset_lens)
+ dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
+ else:
+ assert len(dataset_prob) == len(datasets)
+ dataset_prob = torch.tensor(dataset_prob)
+
+ logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
+ assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
+ self.dataset_prob = dataset_prob
+
+ def _set_dataset_epoch(self, dataset, epoch: int) -> None:
+ if hasattr(dataset, "epoch"):
+ dataset.epoch = epoch
+ if hasattr(dataset, "set_epoch"):
+ dataset.set_epoch(epoch)
+
+ def get_loader(self, epoch) -> Iterable:
+ dataloaders = []
+ for d_idx, (dataset, batch_size) in enumerate(
+ zip(self.datasets, self.batch_sizes)
+ ):
+ if self.phases_per_epoch > 1:
+ # Major epoch that looops over entire dataset
+ # len(main_epoch) == phases_per_epoch * len(epoch)
+ main_epoch = epoch // self.phases_per_epoch
+
+ # Phase with in the main epoch
+ local_phase = epoch % self.phases_per_epoch
+
+ # Start of new data-epoch or job is resumed after preemtion.
+ if local_phase == 0 or self.chunks[d_idx] is None:
+ # set seed for dataset epoch
+ # If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.
+ self._set_dataset_epoch(dataset, main_epoch)
+
+ # Separate random generator for subset sampling
+ g = torch.Generator()
+ g.manual_seed(main_epoch)
+ self.chunks[d_idx] = torch.chunk(
+ torch.randperm(len(dataset), generator=g),
+ self.phases_per_epoch,
+ )
+
+ dataset = Subset(dataset, self.chunks[d_idx][local_phase])
+ else:
+ self._set_dataset_epoch(dataset, epoch)
+
+ sampler = DistributedSampler(dataset, shuffle=self.shuffle)
+ sampler.set_epoch(epoch)
+
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
+ dataloaders.append(
+ DataLoader(
+ dataset,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ batch_sampler=batch_sampler,
+ collate_fn=self.collate_fn,
+ worker_init_fn=self.worker_init_fn,
+ )
+ )
+ return MixedDataLoader(dataloaders, self.dataset_prob)
diff --git a/phantom/submodules/sam2/training/dataset/transforms.py b/phantom/submodules/sam2/training/dataset/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e5c6512ac7fd9548273fb152a3b57ef75e4fc18
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/transforms.py
@@ -0,0 +1,528 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Transforms and data augmentation for both image + bbox.
+"""
+
+import logging
+
+import random
+from typing import Iterable
+
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+import torchvision.transforms.v2.functional as Fv2
+from PIL import Image as PILImage
+
+from torchvision.transforms import InterpolationMode
+
+from training.utils.data_utils import VideoDatapoint
+
+
+def hflip(datapoint, index):
+
+ datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
+ for obj in datapoint.frames[index].objects:
+ if obj.segment is not None:
+ obj.segment = F.hflip(obj.segment)
+
+ return datapoint
+
+
+def get_size_with_aspect_ratio(image_size, size, max_size=None):
+ w, h = image_size
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = max_size * min_original_size / max_original_size
+
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (h, w)
+
+ if w < h:
+ ow = int(round(size))
+ oh = int(round(size * h / w))
+ else:
+ oh = int(round(size))
+ ow = int(round(size * w / h))
+
+ return (oh, ow)
+
+
+def resize(datapoint, index, size, max_size=None, square=False, v2=False):
+ # size can be min_size (scalar) or (w, h) tuple
+
+ def get_size(image_size, size, max_size=None):
+ if isinstance(size, (list, tuple)):
+ return size[::-1]
+ else:
+ return get_size_with_aspect_ratio(image_size, size, max_size)
+
+ if square:
+ size = size, size
+ else:
+ cur_size = (
+ datapoint.frames[index].data.size()[-2:][::-1]
+ if v2
+ else datapoint.frames[index].data.size
+ )
+ size = get_size(cur_size, size, max_size)
+
+ old_size = (
+ datapoint.frames[index].data.size()[-2:][::-1]
+ if v2
+ else datapoint.frames[index].data.size
+ )
+ if v2:
+ datapoint.frames[index].data = Fv2.resize(
+ datapoint.frames[index].data, size, antialias=True
+ )
+ else:
+ datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)
+
+ new_size = (
+ datapoint.frames[index].data.size()[-2:][::-1]
+ if v2
+ else datapoint.frames[index].data.size
+ )
+
+ for obj in datapoint.frames[index].objects:
+ if obj.segment is not None:
+ obj.segment = F.resize(obj.segment[None, None], size).squeeze()
+
+ h, w = size
+ datapoint.frames[index].size = (h, w)
+ return datapoint
+
+
+def pad(datapoint, index, padding, v2=False):
+ old_h, old_w = datapoint.frames[index].size
+ h, w = old_h, old_w
+ if len(padding) == 2:
+ # assumes that we only pad on the bottom right corners
+ datapoint.frames[index].data = F.pad(
+ datapoint.frames[index].data, (0, 0, padding[0], padding[1])
+ )
+ h += padding[1]
+ w += padding[0]
+ else:
+ # left, top, right, bottom
+ datapoint.frames[index].data = F.pad(
+ datapoint.frames[index].data,
+ (padding[0], padding[1], padding[2], padding[3]),
+ )
+ h += padding[1] + padding[3]
+ w += padding[0] + padding[2]
+
+ datapoint.frames[index].size = (h, w)
+
+ for obj in datapoint.frames[index].objects:
+ if obj.segment is not None:
+ if v2:
+ if len(padding) == 2:
+ obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
+ else:
+ obj.segment = Fv2.pad(obj.segment, tuple(padding))
+ else:
+ if len(padding) == 2:
+ obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
+ else:
+ obj.segment = F.pad(obj.segment, tuple(padding))
+ return datapoint
+
+
+class RandomHorizontalFlip:
+ def __init__(self, consistent_transform, p=0.5):
+ self.p = p
+ self.consistent_transform = consistent_transform
+
+ def __call__(self, datapoint, **kwargs):
+ if self.consistent_transform:
+ if random.random() < self.p:
+ for i in range(len(datapoint.frames)):
+ datapoint = hflip(datapoint, i)
+ return datapoint
+ for i in range(len(datapoint.frames)):
+ if random.random() < self.p:
+ datapoint = hflip(datapoint, i)
+ return datapoint
+
+
+class RandomResizeAPI:
+ def __init__(
+ self, sizes, consistent_transform, max_size=None, square=False, v2=False
+ ):
+ if isinstance(sizes, int):
+ sizes = (sizes,)
+ assert isinstance(sizes, Iterable)
+ self.sizes = list(sizes)
+ self.max_size = max_size
+ self.square = square
+ self.consistent_transform = consistent_transform
+ self.v2 = v2
+
+ def __call__(self, datapoint, **kwargs):
+ if self.consistent_transform:
+ size = random.choice(self.sizes)
+ for i in range(len(datapoint.frames)):
+ datapoint = resize(
+ datapoint, i, size, self.max_size, square=self.square, v2=self.v2
+ )
+ return datapoint
+ for i in range(len(datapoint.frames)):
+ size = random.choice(self.sizes)
+ datapoint = resize(
+ datapoint, i, size, self.max_size, square=self.square, v2=self.v2
+ )
+ return datapoint
+
+
+class ToTensorAPI:
+ def __init__(self, v2=False):
+ self.v2 = v2
+
+ def __call__(self, datapoint: VideoDatapoint, **kwargs):
+ for img in datapoint.frames:
+ if self.v2:
+ img.data = Fv2.to_image_tensor(img.data)
+ else:
+ img.data = F.to_tensor(img.data)
+ return datapoint
+
+
+class NormalizeAPI:
+ def __init__(self, mean, std, v2=False):
+ self.mean = mean
+ self.std = std
+ self.v2 = v2
+
+ def __call__(self, datapoint: VideoDatapoint, **kwargs):
+ for img in datapoint.frames:
+ if self.v2:
+ img.data = Fv2.convert_image_dtype(img.data, torch.float32)
+ img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
+ else:
+ img.data = F.normalize(img.data, mean=self.mean, std=self.std)
+
+ return datapoint
+
+
+class ComposeAPI:
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, datapoint, **kwargs):
+ for t in self.transforms:
+ datapoint = t(datapoint, **kwargs)
+ return datapoint
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+
+class RandomGrayscale:
+ def __init__(self, consistent_transform, p=0.5):
+ self.p = p
+ self.consistent_transform = consistent_transform
+ self.Grayscale = T.Grayscale(num_output_channels=3)
+
+ def __call__(self, datapoint: VideoDatapoint, **kwargs):
+ if self.consistent_transform:
+ if random.random() < self.p:
+ for img in datapoint.frames:
+ img.data = self.Grayscale(img.data)
+ return datapoint
+ for img in datapoint.frames:
+ if random.random() < self.p:
+ img.data = self.Grayscale(img.data)
+ return datapoint
+
+
+class ColorJitter:
+ def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
+ self.consistent_transform = consistent_transform
+ self.brightness = (
+ brightness
+ if isinstance(brightness, list)
+ else [max(0, 1 - brightness), 1 + brightness]
+ )
+ self.contrast = (
+ contrast
+ if isinstance(contrast, list)
+ else [max(0, 1 - contrast), 1 + contrast]
+ )
+ self.saturation = (
+ saturation
+ if isinstance(saturation, list)
+ else [max(0, 1 - saturation), 1 + saturation]
+ )
+ self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])
+
+ def __call__(self, datapoint: VideoDatapoint, **kwargs):
+ if self.consistent_transform:
+ # Create a color jitter transformation params
+ (
+ fn_idx,
+ brightness_factor,
+ contrast_factor,
+ saturation_factor,
+ hue_factor,
+ ) = T.ColorJitter.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue
+ )
+ for img in datapoint.frames:
+ if not self.consistent_transform:
+ (
+ fn_idx,
+ brightness_factor,
+ contrast_factor,
+ saturation_factor,
+ hue_factor,
+ ) = T.ColorJitter.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue
+ )
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness_factor is not None:
+ img.data = F.adjust_brightness(img.data, brightness_factor)
+ elif fn_id == 1 and contrast_factor is not None:
+ img.data = F.adjust_contrast(img.data, contrast_factor)
+ elif fn_id == 2 and saturation_factor is not None:
+ img.data = F.adjust_saturation(img.data, saturation_factor)
+ elif fn_id == 3 and hue_factor is not None:
+ img.data = F.adjust_hue(img.data, hue_factor)
+ return datapoint
+
+
+class RandomAffine:
+ def __init__(
+ self,
+ degrees,
+ consistent_transform,
+ scale=None,
+ translate=None,
+ shear=None,
+ image_mean=(123, 116, 103),
+ log_warning=True,
+ num_tentatives=1,
+ image_interpolation="bicubic",
+ ):
+ """
+ The mask is required for this transform.
+ if consistent_transform if True, then the same random affine is applied to all frames and masks.
+ """
+ self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
+ self.scale = scale
+ self.shear = (
+ shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
+ )
+ self.translate = translate
+ self.fill_img = image_mean
+ self.consistent_transform = consistent_transform
+ self.log_warning = log_warning
+ self.num_tentatives = num_tentatives
+
+ if image_interpolation == "bicubic":
+ self.image_interpolation = InterpolationMode.BICUBIC
+ elif image_interpolation == "bilinear":
+ self.image_interpolation = InterpolationMode.BILINEAR
+ else:
+ raise NotImplementedError
+
+ def __call__(self, datapoint: VideoDatapoint, **kwargs):
+ for _tentative in range(self.num_tentatives):
+ res = self.transform_datapoint(datapoint)
+ if res is not None:
+ return res
+
+ if self.log_warning:
+ logging.warning(
+ f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
+ )
+ return datapoint
+
+ def transform_datapoint(self, datapoint: VideoDatapoint):
+ _, height, width = F.get_dimensions(datapoint.frames[0].data)
+ img_size = [width, height]
+
+ if self.consistent_transform:
+ # Create a random affine transformation
+ affine_params = T.RandomAffine.get_params(
+ degrees=self.degrees,
+ translate=self.translate,
+ scale_ranges=self.scale,
+ shears=self.shear,
+ img_size=img_size,
+ )
+
+ for img_idx, img in enumerate(datapoint.frames):
+ this_masks = [
+ obj.segment.unsqueeze(0) if obj.segment is not None else None
+ for obj in img.objects
+ ]
+ if not self.consistent_transform:
+ # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation
+ affine_params = T.RandomAffine.get_params(
+ degrees=self.degrees,
+ translate=self.translate,
+ scale_ranges=self.scale,
+ shears=self.shear,
+ img_size=img_size,
+ )
+
+ transformed_bboxes, transformed_masks = [], []
+ for i in range(len(img.objects)):
+ if this_masks[i] is None:
+ transformed_masks.append(None)
+ # Dummy bbox for a dummy target
+ transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
+ else:
+ transformed_mask = F.affine(
+ this_masks[i],
+ *affine_params,
+ interpolation=InterpolationMode.NEAREST,
+ fill=0.0,
+ )
+ if img_idx == 0 and transformed_mask.max() == 0:
+ # We are dealing with a video and the object is not visible in the first frame
+ # Return the datapoint without transformation
+ return None
+ transformed_masks.append(transformed_mask.squeeze())
+
+ for i in range(len(img.objects)):
+ img.objects[i].segment = transformed_masks[i]
+
+ img.data = F.affine(
+ img.data,
+ *affine_params,
+ interpolation=self.image_interpolation,
+ fill=self.fill_img,
+ )
+ return datapoint
+
+
+def random_mosaic_frame(
+ datapoint,
+ index,
+ grid_h,
+ grid_w,
+ target_grid_y,
+ target_grid_x,
+ should_hflip,
+):
+ # Step 1: downsize the images and paste them into a mosaic
+ image_data = datapoint.frames[index].data
+ is_pil = isinstance(image_data, PILImage.Image)
+ if is_pil:
+ H_im = image_data.height
+ W_im = image_data.width
+ image_data_output = PILImage.new("RGB", (W_im, H_im))
+ else:
+ H_im = image_data.size(-2)
+ W_im = image_data.size(-1)
+ image_data_output = torch.zeros_like(image_data)
+
+ downsize_cache = {}
+ for grid_y in range(grid_h):
+ for grid_x in range(grid_w):
+ y_offset_b = grid_y * H_im // grid_h
+ x_offset_b = grid_x * W_im // grid_w
+ y_offset_e = (grid_y + 1) * H_im // grid_h
+ x_offset_e = (grid_x + 1) * W_im // grid_w
+ H_im_downsize = y_offset_e - y_offset_b
+ W_im_downsize = x_offset_e - x_offset_b
+
+ if (H_im_downsize, W_im_downsize) in downsize_cache:
+ image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
+ else:
+ image_data_downsize = F.resize(
+ image_data,
+ size=(H_im_downsize, W_im_downsize),
+ interpolation=InterpolationMode.BILINEAR,
+ antialias=True, # antialiasing for downsizing
+ )
+ downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
+ if should_hflip[grid_y, grid_x].item():
+ image_data_downsize = F.hflip(image_data_downsize)
+
+ if is_pil:
+ image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
+ else:
+ image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
+ image_data_downsize
+ )
+
+ datapoint.frames[index].data = image_data_output
+
+ # Step 2: downsize the masks and paste them into the target grid of the mosaic
+ for obj in datapoint.frames[index].objects:
+ if obj.segment is None:
+ continue
+ assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
+ segment_output = torch.zeros_like(obj.segment)
+
+ target_y_offset_b = target_grid_y * H_im // grid_h
+ target_x_offset_b = target_grid_x * W_im // grid_w
+ target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
+ target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
+ target_H_im_downsize = target_y_offset_e - target_y_offset_b
+ target_W_im_downsize = target_x_offset_e - target_x_offset_b
+
+ segment_downsize = F.resize(
+ obj.segment[None, None],
+ size=(target_H_im_downsize, target_W_im_downsize),
+ interpolation=InterpolationMode.BILINEAR,
+ antialias=True, # antialiasing for downsizing
+ )[0, 0]
+ if should_hflip[target_grid_y, target_grid_x].item():
+ segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
+
+ segment_output[
+ target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
+ ] = segment_downsize
+ obj.segment = segment_output
+
+ return datapoint
+
+
+class RandomMosaicVideoAPI:
+ def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
+ self.prob = prob
+ self.grid_h = grid_h
+ self.grid_w = grid_w
+ self.use_random_hflip = use_random_hflip
+
+ def __call__(self, datapoint, **kwargs):
+ if random.random() > self.prob:
+ return datapoint
+
+ # select a random location to place the target mask in the mosaic
+ target_grid_y = random.randint(0, self.grid_h - 1)
+ target_grid_x = random.randint(0, self.grid_w - 1)
+ # whether to flip each grid in the mosaic horizontally
+ if self.use_random_hflip:
+ should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
+ else:
+ should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
+ for i in range(len(datapoint.frames)):
+ datapoint = random_mosaic_frame(
+ datapoint,
+ i,
+ grid_h=self.grid_h,
+ grid_w=self.grid_w,
+ target_grid_y=target_grid_y,
+ target_grid_x=target_grid_x,
+ should_hflip=should_hflip,
+ )
+
+ return datapoint
diff --git a/phantom/submodules/sam2/training/dataset/utils.py b/phantom/submodules/sam2/training/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a658df234c3dcf74404f844b5be793b0545485ed
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/utils.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
+
+from typing import Iterable
+
+import torch
+from torch.utils.data import (
+ ConcatDataset as TorchConcatDataset,
+ Dataset,
+ Subset as TorchSubset,
+)
+
+
+class ConcatDataset(TorchConcatDataset):
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
+ super(ConcatDataset, self).__init__(datasets)
+
+ self.repeat_factors = torch.cat([d.repeat_factors for d in datasets])
+
+ def set_epoch(self, epoch: int):
+ for dataset in self.datasets:
+ if hasattr(dataset, "epoch"):
+ dataset.epoch = epoch
+ if hasattr(dataset, "set_epoch"):
+ dataset.set_epoch(epoch)
+
+
+class Subset(TorchSubset):
+ def __init__(self, dataset, indices) -> None:
+ super(Subset, self).__init__(dataset, indices)
+
+ self.repeat_factors = dataset.repeat_factors[indices]
+ assert len(indices) == len(self.repeat_factors)
+
+
+# Adapted from Detectron2
+class RepeatFactorWrapper(Dataset):
+ """
+ Thin wrapper around a dataset to implement repeat factor sampling.
+ The underlying dataset must have a repeat_factors member to indicate the per-image factor.
+ Set it to uniformly ones to disable repeat factor sampling
+ """
+
+ def __init__(self, dataset, seed: int = 0):
+ self.dataset = dataset
+ self.epoch_ids = None
+ self._seed = seed
+
+ # Split into whole number (_int_part) and fractional (_frac_part) parts.
+ self._int_part = torch.trunc(dataset.repeat_factors)
+ self._frac_part = dataset.repeat_factors - self._int_part
+
+ def _get_epoch_indices(self, generator):
+ """
+ Create a list of dataset indices (with repeats) to use for one epoch.
+
+ Args:
+ generator (torch.Generator): pseudo random number generator used for
+ stochastic rounding.
+
+ Returns:
+ torch.Tensor: list of dataset indices to use in one epoch. Each index
+ is repeated based on its calculated repeat factor.
+ """
+ # Since repeat factors are fractional, we use stochastic rounding so
+ # that the target repeat factor is achieved in expectation over the
+ # course of training
+ rands = torch.rand(len(self._frac_part), generator=generator)
+ rep_factors = self._int_part + (rands < self._frac_part).float()
+ # Construct a list of indices in which we repeat images as specified
+ indices = []
+ for dataset_index, rep_factor in enumerate(rep_factors):
+ indices.extend([dataset_index] * int(rep_factor.item()))
+ return torch.tensor(indices, dtype=torch.int64)
+
+ def __len__(self):
+ if self.epoch_ids is None:
+ # Here we raise an error instead of returning original len(self.dataset) avoid
+ # accidentally using unwrapped length. Otherwise it's error-prone since the
+ # length changes to `len(self.epoch_ids)`changes after set_epoch is called.
+ raise RuntimeError("please call set_epoch first to get wrapped length")
+ # return len(self.dataset)
+
+ return len(self.epoch_ids)
+
+ def set_epoch(self, epoch: int):
+ g = torch.Generator()
+ g.manual_seed(self._seed + epoch)
+ self.epoch_ids = self._get_epoch_indices(g)
+ if hasattr(self.dataset, "set_epoch"):
+ self.dataset.set_epoch(epoch)
+
+ def __getitem__(self, idx):
+ if self.epoch_ids is None:
+ raise RuntimeError(
+ "Repeat ids haven't been computed. Did you forget to call set_epoch?"
+ )
+
+ return self.dataset[self.epoch_ids[idx]]
diff --git a/phantom/submodules/sam2/training/dataset/vos_dataset.py b/phantom/submodules/sam2/training/dataset/vos_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1e9d39fe184cf0d86fbf22b5385dc05988cab83
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/vos_dataset.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import random
+from copy import deepcopy
+
+import numpy as np
+
+import torch
+from iopath.common.file_io import g_pathmgr
+from PIL import Image as PILImage
+from torchvision.datasets.vision import VisionDataset
+
+from training.dataset.vos_raw_dataset import VOSRawDataset
+from training.dataset.vos_sampler import VOSSampler
+from training.dataset.vos_segment_loader import JSONSegmentLoader
+
+from training.utils.data_utils import Frame, Object, VideoDatapoint
+
+MAX_RETRIES = 100
+
+
+class VOSDataset(VisionDataset):
+ def __init__(
+ self,
+ transforms,
+ training: bool,
+ video_dataset: VOSRawDataset,
+ sampler: VOSSampler,
+ multiplier: int,
+ always_target=True,
+ target_segments_available=True,
+ ):
+ self._transforms = transforms
+ self.training = training
+ self.video_dataset = video_dataset
+ self.sampler = sampler
+
+ self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
+ self.repeat_factors *= multiplier
+ print(f"Raw dataset length = {len(self.video_dataset)}")
+
+ self.curr_epoch = 0 # Used in case data loader behavior changes across epochs
+ self.always_target = always_target
+ self.target_segments_available = target_segments_available
+
+ def _get_datapoint(self, idx):
+
+ for retry in range(MAX_RETRIES):
+ try:
+ if isinstance(idx, torch.Tensor):
+ idx = idx.item()
+ # sample a video
+ video, segment_loader = self.video_dataset.get_video(idx)
+ # sample frames and object indices to be used in a datapoint
+ sampled_frms_and_objs = self.sampler.sample(
+ video, segment_loader, epoch=self.curr_epoch
+ )
+ break # Succesfully loaded video
+ except Exception as e:
+ if self.training:
+ logging.warning(
+ f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
+ )
+ idx = random.randrange(0, len(self.video_dataset))
+ else:
+ # Shouldn't fail to load a val video
+ raise e
+
+ datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
+ for transform in self._transforms:
+ datapoint = transform(datapoint, epoch=self.curr_epoch)
+ return datapoint
+
+ def construct(self, video, sampled_frms_and_objs, segment_loader):
+ """
+ Constructs a VideoDatapoint sample to pass to transforms
+ """
+ sampled_frames = sampled_frms_and_objs.frames
+ sampled_object_ids = sampled_frms_and_objs.object_ids
+
+ images = []
+ rgb_images = load_images(sampled_frames)
+ # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
+ for frame_idx, frame in enumerate(sampled_frames):
+ w, h = rgb_images[frame_idx].size
+ images.append(
+ Frame(
+ data=rgb_images[frame_idx],
+ objects=[],
+ )
+ )
+ # We load the gt segments associated with the current frame
+ if isinstance(segment_loader, JSONSegmentLoader):
+ segments = segment_loader.load(
+ frame.frame_idx, obj_ids=sampled_object_ids
+ )
+ else:
+ segments = segment_loader.load(frame.frame_idx)
+ for obj_id in sampled_object_ids:
+ # Extract the segment
+ if obj_id in segments:
+ assert (
+ segments[obj_id] is not None
+ ), "None targets are not supported"
+ # segment is uint8 and remains uint8 throughout the transforms
+ segment = segments[obj_id].to(torch.uint8)
+ else:
+ # There is no target, we either use a zero mask target or drop this object
+ if not self.always_target:
+ continue
+ segment = torch.zeros(h, w, dtype=torch.uint8)
+
+ images[frame_idx].objects.append(
+ Object(
+ object_id=obj_id,
+ frame_index=frame.frame_idx,
+ segment=segment,
+ )
+ )
+ return VideoDatapoint(
+ frames=images,
+ video_id=video.video_id,
+ size=(h, w),
+ )
+
+ def __getitem__(self, idx):
+ return self._get_datapoint(idx)
+
+ def __len__(self):
+ return len(self.video_dataset)
+
+
+def load_images(frames):
+ all_images = []
+ cache = {}
+ for frame in frames:
+ if frame.data is None:
+ # Load the frame rgb data from file
+ path = frame.image_path
+ if path in cache:
+ all_images.append(deepcopy(all_images[cache[path]]))
+ continue
+ with g_pathmgr.open(path, "rb") as fopen:
+ all_images.append(PILImage.open(fopen).convert("RGB"))
+ cache[path] = len(all_images) - 1
+ else:
+ # The frame rgb data has already been loaded
+ # Convert it to a PILImage
+ all_images.append(tensor_2_PIL(frame.data))
+
+ return all_images
+
+
+def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
+ data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
+ data = data.astype(np.uint8)
+ return PILImage.fromarray(data)
diff --git a/phantom/submodules/sam2/training/dataset/vos_raw_dataset.py b/phantom/submodules/sam2/training/dataset/vos_raw_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fe893717a3e3bd85b043baa33d349b52b4b34e
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/vos_raw_dataset.py
@@ -0,0 +1,308 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import glob
+import logging
+import os
+from dataclasses import dataclass
+
+from typing import List, Optional
+
+import pandas as pd
+
+import torch
+
+from iopath.common.file_io import g_pathmgr
+
+from omegaconf.listconfig import ListConfig
+
+from training.dataset.vos_segment_loader import (
+ JSONSegmentLoader,
+ MultiplePNGSegmentLoader,
+ PalettisedPNGSegmentLoader,
+ SA1BSegmentLoader,
+)
+
+
+@dataclass
+class VOSFrame:
+ frame_idx: int
+ image_path: str
+ data: Optional[torch.Tensor] = None
+ is_conditioning_only: Optional[bool] = False
+
+
+@dataclass
+class VOSVideo:
+ video_name: str
+ video_id: int
+ frames: List[VOSFrame]
+
+ def __len__(self):
+ return len(self.frames)
+
+
+class VOSRawDataset:
+ def __init__(self):
+ pass
+
+ def get_video(self, idx):
+ raise NotImplementedError()
+
+
+class PNGRawDataset(VOSRawDataset):
+ def __init__(
+ self,
+ img_folder,
+ gt_folder,
+ file_list_txt=None,
+ excluded_videos_list_txt=None,
+ sample_rate=1,
+ is_palette=True,
+ single_object_mode=False,
+ truncate_video=-1,
+ frames_sampling_mult=False,
+ ):
+ self.img_folder = img_folder
+ self.gt_folder = gt_folder
+ self.sample_rate = sample_rate
+ self.is_palette = is_palette
+ self.single_object_mode = single_object_mode
+ self.truncate_video = truncate_video
+
+ # Read the subset defined in file_list_txt
+ if file_list_txt is not None:
+ with g_pathmgr.open(file_list_txt, "r") as f:
+ subset = [os.path.splitext(line.strip())[0] for line in f]
+ else:
+ subset = os.listdir(self.img_folder)
+
+ # Read and process excluded files if provided
+ if excluded_videos_list_txt is not None:
+ with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
+ excluded_files = [os.path.splitext(line.strip())[0] for line in f]
+ else:
+ excluded_files = []
+
+ # Check if it's not in excluded_files
+ self.video_names = sorted(
+ [video_name for video_name in subset if video_name not in excluded_files]
+ )
+
+ if self.single_object_mode:
+ # single object mode
+ self.video_names = sorted(
+ [
+ os.path.join(video_name, obj)
+ for video_name in self.video_names
+ for obj in os.listdir(os.path.join(self.gt_folder, video_name))
+ ]
+ )
+
+ if frames_sampling_mult:
+ video_names_mult = []
+ for video_name in self.video_names:
+ num_frames = len(os.listdir(os.path.join(self.img_folder, video_name)))
+ video_names_mult.extend([video_name] * num_frames)
+ self.video_names = video_names_mult
+
+ def get_video(self, idx):
+ """
+ Given a VOSVideo object, return the mask tensors.
+ """
+ video_name = self.video_names[idx]
+
+ if self.single_object_mode:
+ video_frame_root = os.path.join(
+ self.img_folder, os.path.dirname(video_name)
+ )
+ else:
+ video_frame_root = os.path.join(self.img_folder, video_name)
+
+ video_mask_root = os.path.join(self.gt_folder, video_name)
+
+ if self.is_palette:
+ segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
+ else:
+ segment_loader = MultiplePNGSegmentLoader(
+ video_mask_root, self.single_object_mode
+ )
+
+ all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
+ if self.truncate_video > 0:
+ all_frames = all_frames[: self.truncate_video]
+ frames = []
+ for _, fpath in enumerate(all_frames[:: self.sample_rate]):
+ fid = int(os.path.basename(fpath).split(".")[0])
+ frames.append(VOSFrame(fid, image_path=fpath))
+ video = VOSVideo(video_name, idx, frames)
+ return video, segment_loader
+
+ def __len__(self):
+ return len(self.video_names)
+
+
+class SA1BRawDataset(VOSRawDataset):
+ def __init__(
+ self,
+ img_folder,
+ gt_folder,
+ file_list_txt=None,
+ excluded_videos_list_txt=None,
+ num_frames=1,
+ mask_area_frac_thresh=1.1, # no filtering by default
+ uncertain_iou=-1, # no filtering by default
+ ):
+ self.img_folder = img_folder
+ self.gt_folder = gt_folder
+ self.num_frames = num_frames
+ self.mask_area_frac_thresh = mask_area_frac_thresh
+ self.uncertain_iou = uncertain_iou # stability score
+
+ # Read the subset defined in file_list_txt
+ if file_list_txt is not None:
+ with g_pathmgr.open(file_list_txt, "r") as f:
+ subset = [os.path.splitext(line.strip())[0] for line in f]
+ else:
+ subset = os.listdir(self.img_folder)
+ subset = [
+ path.split(".")[0] for path in subset if path.endswith(".jpg")
+ ] # remove extension
+
+ # Read and process excluded files if provided
+ if excluded_videos_list_txt is not None:
+ with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
+ excluded_files = [os.path.splitext(line.strip())[0] for line in f]
+ else:
+ excluded_files = []
+
+ # Check if it's not in excluded_files and it exists
+ self.video_names = [
+ video_name for video_name in subset if video_name not in excluded_files
+ ]
+
+ def get_video(self, idx):
+ """
+ Given a VOSVideo object, return the mask tensors.
+ """
+ video_name = self.video_names[idx]
+
+ video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
+ video_mask_path = os.path.join(self.gt_folder, video_name + ".json")
+
+ segment_loader = SA1BSegmentLoader(
+ video_mask_path,
+ mask_area_frac_thresh=self.mask_area_frac_thresh,
+ video_frame_path=video_frame_path,
+ uncertain_iou=self.uncertain_iou,
+ )
+
+ frames = []
+ for frame_idx in range(self.num_frames):
+ frames.append(VOSFrame(frame_idx, image_path=video_frame_path))
+ video_name = video_name.split("_")[-1] # filename is sa_{int}
+ # video id needs to be image_id to be able to load correct annotation file during eval
+ video = VOSVideo(video_name, int(video_name), frames)
+ return video, segment_loader
+
+ def __len__(self):
+ return len(self.video_names)
+
+
+class JSONRawDataset(VOSRawDataset):
+ """
+ Dataset where the annotation in the format of SA-V json files
+ """
+
+ def __init__(
+ self,
+ img_folder,
+ gt_folder,
+ file_list_txt=None,
+ excluded_videos_list_txt=None,
+ sample_rate=1,
+ rm_unannotated=True,
+ ann_every=1,
+ frames_fps=24,
+ ):
+ self.gt_folder = gt_folder
+ self.img_folder = img_folder
+ self.sample_rate = sample_rate
+ self.rm_unannotated = rm_unannotated
+ self.ann_every = ann_every
+ self.frames_fps = frames_fps
+
+ # Read and process excluded files if provided
+ excluded_files = []
+ if excluded_videos_list_txt is not None:
+ if isinstance(excluded_videos_list_txt, str):
+ excluded_videos_lists = [excluded_videos_list_txt]
+ elif isinstance(excluded_videos_list_txt, ListConfig):
+ excluded_videos_lists = list(excluded_videos_list_txt)
+ else:
+ raise NotImplementedError
+
+ for excluded_videos_list_txt in excluded_videos_lists:
+ with open(excluded_videos_list_txt, "r") as f:
+ excluded_files.extend(
+ [os.path.splitext(line.strip())[0] for line in f]
+ )
+ excluded_files = set(excluded_files)
+
+ # Read the subset defined in file_list_txt
+ if file_list_txt is not None:
+ with g_pathmgr.open(file_list_txt, "r") as f:
+ subset = [os.path.splitext(line.strip())[0] for line in f]
+ else:
+ subset = os.listdir(self.img_folder)
+
+ self.video_names = sorted(
+ [video_name for video_name in subset if video_name not in excluded_files]
+ )
+
+ def get_video(self, video_idx):
+ """
+ Given a VOSVideo object, return the mask tensors.
+ """
+ video_name = self.video_names[video_idx]
+ video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
+ segment_loader = JSONSegmentLoader(
+ video_json_path=video_json_path,
+ ann_every=self.ann_every,
+ frames_fps=self.frames_fps,
+ )
+
+ frame_ids = [
+ int(os.path.splitext(frame_name)[0])
+ for frame_name in sorted(
+ os.listdir(os.path.join(self.img_folder, video_name))
+ )
+ ]
+
+ frames = [
+ VOSFrame(
+ frame_id,
+ image_path=os.path.join(
+ self.img_folder, f"{video_name}/%05d.jpg" % (frame_id)
+ ),
+ )
+ for frame_id in frame_ids[:: self.sample_rate]
+ ]
+
+ if self.rm_unannotated:
+ # Eliminate the frames that have not been annotated
+ valid_frame_ids = [
+ i * segment_loader.ann_every
+ for i, annot in enumerate(segment_loader.frame_annots)
+ if annot is not None and None not in annot
+ ]
+ frames = [f for f in frames if f.frame_idx in valid_frame_ids]
+
+ video = VOSVideo(video_name, video_idx, frames)
+ return video, segment_loader
+
+ def __len__(self):
+ return len(self.video_names)
diff --git a/phantom/submodules/sam2/training/dataset/vos_sampler.py b/phantom/submodules/sam2/training/dataset/vos_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad84b759d0f66191a84017d17140d128b634ca0
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/vos_sampler.py
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+from dataclasses import dataclass
+from typing import List
+
+from training.dataset.vos_segment_loader import LazySegments
+
+MAX_RETRIES = 1000
+
+
+@dataclass
+class SampledFramesAndObjects:
+ frames: List[int]
+ object_ids: List[int]
+
+
+class VOSSampler:
+ def __init__(self, sort_frames=True):
+ # frames are ordered by frame id when sort_frames is True
+ self.sort_frames = sort_frames
+
+ def sample(self, video):
+ raise NotImplementedError()
+
+
+class RandomUniformSampler(VOSSampler):
+ def __init__(
+ self,
+ num_frames,
+ max_num_objects,
+ reverse_time_prob=0.0,
+ ):
+ self.num_frames = num_frames
+ self.max_num_objects = max_num_objects
+ self.reverse_time_prob = reverse_time_prob
+
+ def sample(self, video, segment_loader, epoch=None):
+
+ for retry in range(MAX_RETRIES):
+ if len(video.frames) < self.num_frames:
+ raise Exception(
+ f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
+ )
+ start = random.randrange(0, len(video.frames) - self.num_frames + 1)
+ frames = [video.frames[start + step] for step in range(self.num_frames)]
+ if random.uniform(0, 1) < self.reverse_time_prob:
+ # Reverse time
+ frames = frames[::-1]
+
+ # Get first frame object ids
+ visible_object_ids = []
+ loaded_segms = segment_loader.load(frames[0].frame_idx)
+ if isinstance(loaded_segms, LazySegments):
+ # LazySegments for SA1BRawDataset
+ visible_object_ids = list(loaded_segms.keys())
+ else:
+ for object_id, segment in segment_loader.load(
+ frames[0].frame_idx
+ ).items():
+ if segment.sum():
+ visible_object_ids.append(object_id)
+
+ # First frame needs to have at least a target to track
+ if len(visible_object_ids) > 0:
+ break
+ if retry >= MAX_RETRIES - 1:
+ raise Exception("No visible objects")
+
+ object_ids = random.sample(
+ visible_object_ids,
+ min(len(visible_object_ids), self.max_num_objects),
+ )
+ return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
+
+
+class EvalSampler(VOSSampler):
+ """
+ VOS Sampler for evaluation: sampling all the frames and all the objects in a video
+ """
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+
+ def sample(self, video, segment_loader, epoch=None):
+ """
+ Sampling all the frames and all the objects
+ """
+ if self.sort_frames:
+ # ordered by frame id
+ frames = sorted(video.frames, key=lambda x: x.frame_idx)
+ else:
+ # use the original order
+ frames = video.frames
+ object_ids = segment_loader.load(frames[0].frame_idx).keys()
+ if len(object_ids) == 0:
+ raise Exception("First frame of the video has no objects")
+
+ return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
diff --git a/phantom/submodules/sam2/training/dataset/vos_segment_loader.py b/phantom/submodules/sam2/training/dataset/vos_segment_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..27e17010cc8b010e103c3ac399689d80da7cfde9
--- /dev/null
+++ b/phantom/submodules/sam2/training/dataset/vos_segment_loader.py
@@ -0,0 +1,300 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import glob
+import json
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+
+from PIL import Image as PILImage
+
+try:
+ from pycocotools import mask as mask_utils
+except:
+ pass
+
+
+class JSONSegmentLoader:
+ def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
+ # Annotations in the json are provided every ann_every th frame
+ self.ann_every = ann_every
+ # Ids of the objects to consider when sampling this video
+ self.valid_obj_ids = valid_obj_ids
+ with open(video_json_path, "r") as f:
+ data = json.load(f)
+ if isinstance(data, list):
+ self.frame_annots = data
+ elif isinstance(data, dict):
+ masklet_field_name = "masklet" if "masklet" in data else "masks"
+ self.frame_annots = data[masklet_field_name]
+ if "fps" in data:
+ if isinstance(data["fps"], list):
+ annotations_fps = int(data["fps"][0])
+ else:
+ annotations_fps = int(data["fps"])
+ assert frames_fps % annotations_fps == 0
+ self.ann_every = frames_fps // annotations_fps
+ else:
+ raise NotImplementedError
+
+ def load(self, frame_id, obj_ids=None):
+ assert frame_id % self.ann_every == 0
+ rle_mask = self.frame_annots[frame_id // self.ann_every]
+
+ valid_objs_ids = set(range(len(rle_mask)))
+ if self.valid_obj_ids is not None:
+ # Remove the masklets that have been filtered out for this video
+ valid_objs_ids &= set(self.valid_obj_ids)
+ if obj_ids is not None:
+ # Only keep the objects that have been sampled
+ valid_objs_ids &= set(obj_ids)
+ valid_objs_ids = sorted(list(valid_objs_ids))
+
+ # Construct rle_masks_filtered that only contains the rle masks we are interested in
+ id_2_idx = {}
+ rle_mask_filtered = []
+ for obj_id in valid_objs_ids:
+ if rle_mask[obj_id] is not None:
+ id_2_idx[obj_id] = len(rle_mask_filtered)
+ rle_mask_filtered.append(rle_mask[obj_id])
+ else:
+ id_2_idx[obj_id] = None
+
+ # Decode the masks
+ raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
+ 2, 0, 1
+ ) # (num_obj, h, w)
+ segments = {}
+ for obj_id in valid_objs_ids:
+ if id_2_idx[obj_id] is None:
+ segments[obj_id] = None
+ else:
+ idx = id_2_idx[obj_id]
+ segments[obj_id] = raw_segments[idx]
+ return segments
+
+ def get_valid_obj_frames_ids(self, num_frames_min=None):
+ # For each object, find all the frames with a valid (not None) mask
+ num_objects = len(self.frame_annots[0])
+
+ # The result dict associates each obj_id with the id of its valid frames
+ res = {obj_id: [] for obj_id in range(num_objects)}
+
+ for annot_idx, annot in enumerate(self.frame_annots):
+ for obj_id in range(num_objects):
+ if annot[obj_id] is not None:
+ res[obj_id].append(int(annot_idx * self.ann_every))
+
+ if num_frames_min is not None:
+ # Remove masklets that have less than num_frames_min valid masks
+ for obj_id, valid_frames in list(res.items()):
+ if len(valid_frames) < num_frames_min:
+ res.pop(obj_id)
+
+ return res
+
+
+class PalettisedPNGSegmentLoader:
+ def __init__(self, video_png_root):
+ """
+ SegmentLoader for datasets with masks stored as palettised PNGs.
+ video_png_root: the folder contains all the masks stored in png
+ """
+ self.video_png_root = video_png_root
+ # build a mapping from frame id to their PNG mask path
+ # note that in some datasets, the PNG paths could have more
+ # than 5 digits, e.g. "00000000.png" instead of "00000.png"
+ png_filenames = os.listdir(self.video_png_root)
+ self.frame_id_to_png_filename = {}
+ for filename in png_filenames:
+ frame_id, _ = os.path.splitext(filename)
+ self.frame_id_to_png_filename[int(frame_id)] = filename
+
+ def load(self, frame_id):
+ """
+ load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
+ Args:
+ frame_id: int, define the mask path
+ Return:
+ binary_segments: dict
+ """
+ # check the path
+ mask_path = os.path.join(
+ self.video_png_root, self.frame_id_to_png_filename[frame_id]
+ )
+
+ # load the mask
+ masks = PILImage.open(mask_path).convert("P")
+ masks = np.array(masks)
+
+ object_id = pd.unique(masks.flatten())
+ object_id = object_id[object_id != 0] # remove background (0)
+
+ # convert into N binary segmentation masks
+ binary_segments = {}
+ for i in object_id:
+ bs = masks == i
+ binary_segments[i] = torch.from_numpy(bs)
+
+ return binary_segments
+
+ def __len__(self):
+ return
+
+
+class MultiplePNGSegmentLoader:
+ def __init__(self, video_png_root, single_object_mode=False):
+ """
+ video_png_root: the folder contains all the masks stored in png
+ single_object_mode: whether to load only a single object at a time
+ """
+ self.video_png_root = video_png_root
+ self.single_object_mode = single_object_mode
+ # read a mask to know the resolution of the video
+ if self.single_object_mode:
+ tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
+ else:
+ tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
+ tmp_mask = np.array(PILImage.open(tmp_mask_path))
+ self.H = tmp_mask.shape[0]
+ self.W = tmp_mask.shape[1]
+ if self.single_object_mode:
+ self.obj_id = (
+ int(video_png_root.split("/")[-1]) + 1
+ ) # offset by 1 as bg is 0
+ else:
+ self.obj_id = None
+
+ def load(self, frame_id):
+ if self.single_object_mode:
+ return self._load_single_png(frame_id)
+ else:
+ return self._load_multiple_pngs(frame_id)
+
+ def _load_single_png(self, frame_id):
+ """
+ load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
+ Args:
+ frame_id: int, define the mask path
+ Return:
+ binary_segments: dict
+ """
+ mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
+ binary_segments = {}
+
+ if os.path.exists(mask_path):
+ mask = np.array(PILImage.open(mask_path))
+ else:
+ # if png doesn't exist, empty mask
+ mask = np.zeros((self.H, self.W), dtype=bool)
+ binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
+ return binary_segments
+
+ def _load_multiple_pngs(self, frame_id):
+ """
+ load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
+ Args:
+ frame_id: int, define the mask path
+ Return:
+ binary_segments: dict
+ """
+ # get the path
+ all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
+ num_objects = len(all_objects)
+ assert num_objects > 0
+
+ # load the masks
+ binary_segments = {}
+ for obj_folder in all_objects:
+ # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
+ obj_id = int(obj_folder.split("/")[-1])
+ obj_id = obj_id + 1 # offset 1 as bg is 0
+ mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
+ if os.path.exists(mask_path):
+ mask = np.array(PILImage.open(mask_path))
+ else:
+ mask = np.zeros((self.H, self.W), dtype=bool)
+ binary_segments[obj_id] = torch.from_numpy(mask > 0)
+
+ return binary_segments
+
+ def __len__(self):
+ return
+
+
+class LazySegments:
+ """
+ Only decodes segments that are actually used.
+ """
+
+ def __init__(self):
+ self.segments = {}
+ self.cache = {}
+
+ def __setitem__(self, key, item):
+ self.segments[key] = item
+
+ def __getitem__(self, key):
+ if key in self.cache:
+ return self.cache[key]
+ rle = self.segments[key]
+ mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
+ self.cache[key] = mask
+ return mask
+
+ def __contains__(self, key):
+ return key in self.segments
+
+ def __len__(self):
+ return len(self.segments)
+
+ def keys(self):
+ return self.segments.keys()
+
+
+class SA1BSegmentLoader:
+ def __init__(
+ self,
+ video_mask_path,
+ mask_area_frac_thresh=1.1,
+ video_frame_path=None,
+ uncertain_iou=-1,
+ ):
+ with open(video_mask_path, "r") as f:
+ self.frame_annots = json.load(f)
+
+ if mask_area_frac_thresh <= 1.0:
+ # Lazily read frame
+ orig_w, orig_h = PILImage.open(video_frame_path).size
+ area = orig_w * orig_h
+
+ self.frame_annots = self.frame_annots["annotations"]
+
+ rle_masks = []
+ for frame_annot in self.frame_annots:
+ if not frame_annot["area"] > 0:
+ continue
+ if ("uncertain_iou" in frame_annot) and (
+ frame_annot["uncertain_iou"] < uncertain_iou
+ ):
+ # uncertain_iou is stability score
+ continue
+ if (
+ mask_area_frac_thresh <= 1.0
+ and (frame_annot["area"] / area) >= mask_area_frac_thresh
+ ):
+ continue
+ rle_masks.append(frame_annot["segmentation"])
+
+ self.segments = LazySegments()
+ for i, rle in enumerate(rle_masks):
+ self.segments[i] = rle
+
+ def load(self, frame_idx):
+ return self.segments
diff --git a/phantom/submodules/sam2/training/loss_fns.py b/phantom/submodules/sam2/training/loss_fns.py
new file mode 100644
index 0000000000000000000000000000000000000000..d281b1a9c059771ee0ae3a4d4426f1e445178110
--- /dev/null
+++ b/phantom/submodules/sam2/training/loss_fns.py
@@ -0,0 +1,307 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from typing import Dict, List
+
+import torch
+import torch.distributed
+import torch.nn as nn
+import torch.nn.functional as F
+
+from training.trainer import CORE_LOSS_KEY
+
+from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
+
+
+def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ num_objects: Number of objects in the batch
+ loss_on_multimask: True if multimask prediction is enabled
+ Returns:
+ Dice loss tensor
+ """
+ inputs = inputs.sigmoid()
+ if loss_on_multimask:
+ # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
+ assert inputs.dim() == 4 and targets.dim() == 4
+ # flatten spatial dimension while keeping multimask channel dimension
+ inputs = inputs.flatten(2)
+ targets = targets.flatten(2)
+ numerator = 2 * (inputs * targets).sum(-1)
+ else:
+ inputs = inputs.flatten(1)
+ numerator = 2 * (inputs * targets).sum(1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ if loss_on_multimask:
+ return loss / num_objects
+ return loss.sum() / num_objects
+
+
+def sigmoid_focal_loss(
+ inputs,
+ targets,
+ num_objects,
+ alpha: float = 0.25,
+ gamma: float = 2,
+ loss_on_multimask=False,
+):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ num_objects: Number of objects in the batch
+ alpha: (optional) Weighting factor in range (0,1) to balance
+ positive vs negative examples. Default = -1 (no weighting).
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ loss_on_multimask: True if multimask prediction is enabled
+ Returns:
+ focal loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ if loss_on_multimask:
+ # loss is [N, M, H, W] where M corresponds to multiple predicted masks
+ assert loss.dim() == 4
+ return loss.flatten(2).mean(-1) / num_objects # average over spatial dims
+ return loss.mean(1).sum() / num_objects
+
+
+def iou_loss(
+ inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
+):
+ """
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ pred_ious: A float tensor containing the predicted IoUs scores per mask
+ num_objects: Number of objects in the batch
+ loss_on_multimask: True if multimask prediction is enabled
+ use_l1_loss: Whether to use L1 loss is used instead of MSE loss
+ Returns:
+ IoU loss tensor
+ """
+ assert inputs.dim() == 4 and targets.dim() == 4
+ pred_mask = inputs.flatten(2) > 0
+ gt_mask = targets.flatten(2) > 0
+ area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
+ area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
+ actual_ious = area_i / torch.clamp(area_u, min=1.0)
+
+ if use_l1_loss:
+ loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
+ else:
+ loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
+ if loss_on_multimask:
+ return loss / num_objects
+ return loss.sum() / num_objects
+
+
+class MultiStepMultiMasksAndIous(nn.Module):
+ def __init__(
+ self,
+ weight_dict,
+ focal_alpha=0.25,
+ focal_gamma=2,
+ supervise_all_iou=False,
+ iou_use_l1_loss=False,
+ pred_obj_scores=False,
+ focal_gamma_obj_score=0.0,
+ focal_alpha_obj_score=-1,
+ ):
+ """
+ This class computes the multi-step multi-mask and IoU losses.
+ Args:
+ weight_dict: dict containing weights for focal, dice, iou losses
+ focal_alpha: alpha for sigmoid focal loss
+ focal_gamma: gamma for sigmoid focal loss
+ supervise_all_iou: if True, back-prop iou losses for all predicted masks
+ iou_use_l1_loss: use L1 loss instead of MSE loss for iou
+ pred_obj_scores: if True, compute loss for object scores
+ focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
+ focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
+ """
+
+ super().__init__()
+ self.weight_dict = weight_dict
+ self.focal_alpha = focal_alpha
+ self.focal_gamma = focal_gamma
+ assert "loss_mask" in self.weight_dict
+ assert "loss_dice" in self.weight_dict
+ assert "loss_iou" in self.weight_dict
+ if "loss_class" not in self.weight_dict:
+ self.weight_dict["loss_class"] = 0.0
+
+ self.focal_alpha_obj_score = focal_alpha_obj_score
+ self.focal_gamma_obj_score = focal_gamma_obj_score
+ self.supervise_all_iou = supervise_all_iou
+ self.iou_use_l1_loss = iou_use_l1_loss
+ self.pred_obj_scores = pred_obj_scores
+
+ def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
+ assert len(outs_batch) == len(targets_batch)
+ num_objects = torch.tensor(
+ (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
+ ) # Number of objects is fixed within a batch
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_objects)
+ num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()
+
+ losses = defaultdict(int)
+ for outs, targets in zip(outs_batch, targets_batch):
+ cur_losses = self._forward(outs, targets, num_objects)
+ for k, v in cur_losses.items():
+ losses[k] += v
+
+ return losses
+
+ def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
+ """
+ Compute the losses related to the masks: the focal loss and the dice loss.
+ and also the MAE or MSE loss between predicted IoUs and actual IoUs.
+
+ Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
+ of shape [N, M, H, W], where M could be 1 or larger, corresponding to
+ one or multiple predicted masks from a click.
+
+ We back-propagate focal, dice losses only on the prediction channel
+ with the lowest focal+dice loss between predicted mask and ground-truth.
+ If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.
+ """
+
+ target_masks = targets.unsqueeze(1).float()
+ assert target_masks.dim() == 4 # [N, 1, H, W]
+ src_masks_list = outputs["multistep_pred_multimasks_high_res"]
+ ious_list = outputs["multistep_pred_ious"]
+ object_score_logits_list = outputs["multistep_object_score_logits"]
+
+ assert len(src_masks_list) == len(ious_list)
+ assert len(object_score_logits_list) == len(ious_list)
+
+ # accumulate the loss over prediction steps
+ losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
+ for src_masks, ious, object_score_logits in zip(
+ src_masks_list, ious_list, object_score_logits_list
+ ):
+ self._update_losses(
+ losses, src_masks, target_masks, ious, num_objects, object_score_logits
+ )
+ losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
+ return losses
+
+ def _update_losses(
+ self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
+ ):
+ target_masks = target_masks.expand_as(src_masks)
+ # get focal, dice and iou loss on all output masks in a prediction step
+ loss_multimask = sigmoid_focal_loss(
+ src_masks,
+ target_masks,
+ num_objects,
+ alpha=self.focal_alpha,
+ gamma=self.focal_gamma,
+ loss_on_multimask=True,
+ )
+ loss_multidice = dice_loss(
+ src_masks, target_masks, num_objects, loss_on_multimask=True
+ )
+ if not self.pred_obj_scores:
+ loss_class = torch.tensor(
+ 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
+ )
+ target_obj = torch.ones(
+ loss_multimask.shape[0],
+ 1,
+ dtype=loss_multimask.dtype,
+ device=loss_multimask.device,
+ )
+ else:
+ target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
+ ..., None
+ ].float()
+ loss_class = sigmoid_focal_loss(
+ object_score_logits,
+ target_obj,
+ num_objects,
+ alpha=self.focal_alpha_obj_score,
+ gamma=self.focal_gamma_obj_score,
+ )
+
+ loss_multiiou = iou_loss(
+ src_masks,
+ target_masks,
+ ious,
+ num_objects,
+ loss_on_multimask=True,
+ use_l1_loss=self.iou_use_l1_loss,
+ )
+ assert loss_multimask.dim() == 2
+ assert loss_multidice.dim() == 2
+ assert loss_multiiou.dim() == 2
+ if loss_multimask.size(1) > 1:
+ # take the mask indices with the smallest focal + dice loss for back propagation
+ loss_combo = (
+ loss_multimask * self.weight_dict["loss_mask"]
+ + loss_multidice * self.weight_dict["loss_dice"]
+ )
+ best_loss_inds = torch.argmin(loss_combo, dim=-1)
+ batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
+ loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
+ loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
+ # calculate the iou prediction and slot losses only in the index
+ # with the minimum loss for each mask (to be consistent w/ SAM)
+ if self.supervise_all_iou:
+ loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
+ else:
+ loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
+ else:
+ loss_mask = loss_multimask
+ loss_dice = loss_multidice
+ loss_iou = loss_multiiou
+
+ # backprop focal, dice and iou loss only if obj present
+ loss_mask = loss_mask * target_obj
+ loss_dice = loss_dice * target_obj
+ loss_iou = loss_iou * target_obj
+
+ # sum over batch dimension (note that the losses are already divided by num_objects)
+ losses["loss_mask"] += loss_mask.sum()
+ losses["loss_dice"] += loss_dice.sum()
+ losses["loss_iou"] += loss_iou.sum()
+ losses["loss_class"] += loss_class
+
+ def reduce_loss(self, losses):
+ reduced_loss = 0.0
+ for loss_key, weight in self.weight_dict.items():
+ if loss_key not in losses:
+ raise ValueError(f"{type(self)} doesn't compute {loss_key}")
+ if weight != 0:
+ reduced_loss += losses[loss_key] * weight
+
+ return reduced_loss
diff --git a/phantom/submodules/sam2/training/model/__init__.py b/phantom/submodules/sam2/training/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/training/model/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/training/model/sam2.py b/phantom/submodules/sam2/training/model/sam2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7567c4dc99942d48e5890529ba9e3ca265e02d
--- /dev/null
+++ b/phantom/submodules/sam2/training/model/sam2.py
@@ -0,0 +1,541 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import numpy as np
+import torch
+import torch.distributed
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.modeling.sam2_utils import (
+ get_1d_sine_pe,
+ get_next_point,
+ sample_box_points,
+ select_closest_cond_frames,
+)
+
+from sam2.utils.misc import concat_points
+
+from training.utils.data_utils import BatchedVideoDatapoint
+
+
+class SAM2Train(SAM2Base):
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention=None,
+ memory_encoder=None,
+ prob_to_use_pt_input_for_train=0.0,
+ prob_to_use_pt_input_for_eval=0.0,
+ prob_to_use_box_input_for_train=0.0,
+ prob_to_use_box_input_for_eval=0.0,
+ # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames
+ num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame
+ num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame
+ rand_frames_to_correct_for_train=False,
+ rand_frames_to_correct_for_eval=False,
+ # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
+ # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
+ # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
+ # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
+ # these are initial conditioning frames because as we track the video, more conditioning frames might be added
+ # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
+ num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame
+ num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame
+ rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader)
+ rand_init_cond_frames_for_eval=False,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+ # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
+ add_all_frames_to_correct_as_cond=False,
+ # how many additional correction points to sample (on each frame selected to be corrected)
+ # note that the first frame receives an initial input click (in addition to any correction clicks)
+ num_correction_pt_per_frame=7,
+ # method for point sampling during evaluation
+ # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
+ # default to "center" to be consistent with evaluation in the SAM paper
+ pt_sampling_for_eval="center",
+ # During training, we optionally allow sampling the correction points from GT regions
+ # instead of the prediction error regions with a small probability. This might allow the
+ # model to overfit less to the error regions in training datasets
+ prob_to_sample_from_gt_for_train=0.0,
+ use_act_ckpt_iterative_pt_sampling=False,
+ # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
+ # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
+ forward_backbone_per_frame_for_eval=False,
+ freeze_image_encoder=False,
+ **kwargs,
+ ):
+ super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
+ self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
+ self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
+
+ # Point sampler and conditioning frames
+ self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
+ self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
+ self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
+ self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
+ if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
+ logging.info(
+ f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
+ )
+ assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
+ assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval
+
+ self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
+ self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
+ self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
+ self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
+ # Initial multi-conditioning frames
+ self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
+ self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
+ self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
+ self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+ self.num_correction_pt_per_frame = num_correction_pt_per_frame
+ self.pt_sampling_for_eval = pt_sampling_for_eval
+ self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
+ # A random number generator with a fixed initial seed across GPUs
+ self.rng = np.random.default_rng(seed=42)
+
+ if freeze_image_encoder:
+ for p in self.image_encoder.parameters():
+ p.requires_grad = False
+
+ def forward(self, input: BatchedVideoDatapoint):
+ if self.training or not self.forward_backbone_per_frame_for_eval:
+ # precompute image features on all frames before tracking
+ backbone_out = self.forward_image(input.flat_img_batch)
+ else:
+ # defer image feature computation on a frame until it's being tracked
+ backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
+ backbone_out = self.prepare_prompt_inputs(backbone_out, input)
+ previous_stages_out = self.forward_tracking(backbone_out, input)
+
+ return previous_stages_out
+
+ def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
+ """Compute the image backbone features on the fly for the given img_ids."""
+ # Only forward backbone on unique image ids to avoid repetitive computation
+ # (if `img_ids` has only one element, it's already unique so we skip this step).
+ if img_ids.numel() > 1:
+ unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
+ else:
+ unique_img_ids, inv_ids = img_ids, None
+
+ # Compute the image features on those unique image ids
+ image = img_batch[unique_img_ids]
+ backbone_out = self.forward_image(image)
+ (
+ _,
+ vision_feats,
+ vision_pos_embeds,
+ feat_sizes,
+ ) = self._prepare_backbone_features(backbone_out)
+ # Inverse-map image features for `unique_img_ids` to the final image features
+ # for the original input `img_ids`.
+ if inv_ids is not None:
+ image = image[inv_ids]
+ vision_feats = [x[:, inv_ids] for x in vision_feats]
+ vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
+
+ return image, vision_feats, vision_pos_embeds, feat_sizes
+
+ def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
+ """
+ Prepare input mask, point or box prompts. Optionally, we allow tracking from
+ a custom `start_frame_idx` to the end of the video (for evaluation purposes).
+ """
+ # Load the ground-truth masks on all frames (so that we can later
+ # sample correction points from them)
+ # gt_masks_per_frame = {
+ # stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im]
+ # for stage_id, targets in enumerate(input.find_targets)
+ # }
+ gt_masks_per_frame = {
+ stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im]
+ for stage_id, masks in enumerate(input.masks)
+ }
+ # gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form
+ backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
+ num_frames = input.num_frames
+ backbone_out["num_frames"] = num_frames
+
+ # Randomly decide whether to use point inputs or mask inputs
+ if self.training:
+ prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
+ prob_to_use_box_input = self.prob_to_use_box_input_for_train
+ num_frames_to_correct = self.num_frames_to_correct_for_train
+ rand_frames_to_correct = self.rand_frames_to_correct_for_train
+ num_init_cond_frames = self.num_init_cond_frames_for_train
+ rand_init_cond_frames = self.rand_init_cond_frames_for_train
+ else:
+ prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
+ prob_to_use_box_input = self.prob_to_use_box_input_for_eval
+ num_frames_to_correct = self.num_frames_to_correct_for_eval
+ rand_frames_to_correct = self.rand_frames_to_correct_for_eval
+ num_init_cond_frames = self.num_init_cond_frames_for_eval
+ rand_init_cond_frames = self.rand_init_cond_frames_for_eval
+ if num_frames == 1:
+ # here we handle a special case for mixing video + SAM on image training,
+ # where we force using point input for the SAM task on static images
+ prob_to_use_pt_input = 1.0
+ num_frames_to_correct = 1
+ num_init_cond_frames = 1
+ assert num_init_cond_frames >= 1
+ # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
+ use_pt_input = self.rng.random() < prob_to_use_pt_input
+ if rand_init_cond_frames and num_init_cond_frames > 1:
+ # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
+ num_init_cond_frames = self.rng.integers(
+ 1, num_init_cond_frames, endpoint=True
+ )
+ if (
+ use_pt_input
+ and rand_frames_to_correct
+ and num_frames_to_correct > num_init_cond_frames
+ ):
+ # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
+ # correction clicks (only for the case of point input)
+ num_frames_to_correct = self.rng.integers(
+ num_init_cond_frames, num_frames_to_correct, endpoint=True
+ )
+ backbone_out["use_pt_input"] = use_pt_input
+
+ # Sample initial conditioning frames
+ if num_init_cond_frames == 1:
+ init_cond_frames = [start_frame_idx] # starting frame
+ else:
+ # starting frame + randomly selected remaining frames (without replacement)
+ init_cond_frames = [start_frame_idx] + self.rng.choice(
+ range(start_frame_idx + 1, num_frames),
+ num_init_cond_frames - 1,
+ replace=False,
+ ).tolist()
+ backbone_out["init_cond_frames"] = init_cond_frames
+ backbone_out["frames_not_in_init_cond"] = [
+ t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
+ ]
+ # Prepare mask or point inputs on initial conditioning frames
+ backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: }
+ backbone_out["point_inputs_per_frame"] = {} # {frame_idx: }
+ for t in init_cond_frames:
+ if not use_pt_input:
+ backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
+ else:
+ # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input
+ use_box_input = self.rng.random() < prob_to_use_box_input
+ if use_box_input:
+ points, labels = sample_box_points(
+ gt_masks_per_frame[t],
+ )
+ else:
+ # (here we only sample **one initial point** on initial conditioning frames from the
+ # ground-truth mask; we may sample more correction points on the fly)
+ points, labels = get_next_point(
+ gt_masks=gt_masks_per_frame[t],
+ pred_masks=None,
+ method=(
+ "uniform" if self.training else self.pt_sampling_for_eval
+ ),
+ )
+
+ point_inputs = {"point_coords": points, "point_labels": labels}
+ backbone_out["point_inputs_per_frame"][t] = point_inputs
+
+ # Sample frames where we will add correction clicks on the fly
+ # based on the error between prediction and ground-truth masks
+ if not use_pt_input:
+ # no correction points will be sampled when using mask inputs
+ frames_to_add_correction_pt = []
+ elif num_frames_to_correct == num_init_cond_frames:
+ frames_to_add_correction_pt = init_cond_frames
+ else:
+ assert num_frames_to_correct > num_init_cond_frames
+ # initial cond frame + randomly selected remaining frames (without replacement)
+ extra_num = num_frames_to_correct - num_init_cond_frames
+ frames_to_add_correction_pt = (
+ init_cond_frames
+ + self.rng.choice(
+ backbone_out["frames_not_in_init_cond"], extra_num, replace=False
+ ).tolist()
+ )
+ backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt
+
+ return backbone_out
+
+ def forward_tracking(
+ self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
+ ):
+ """Forward video tracking on each frame (and sample correction clicks)."""
+ img_feats_already_computed = backbone_out["backbone_fpn"] is not None
+ if img_feats_already_computed:
+ # Prepare the backbone features
+ # - vision_feats and vision_pos_embeds are in (HW)BC format
+ (
+ _,
+ vision_feats,
+ vision_pos_embeds,
+ feat_sizes,
+ ) = self._prepare_backbone_features(backbone_out)
+
+ # Starting the stage loop
+ num_frames = backbone_out["num_frames"]
+ init_cond_frames = backbone_out["init_cond_frames"]
+ frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
+ # first process all the initial conditioning frames to encode them as memory,
+ # and then conditioning on them to track the remaining frames
+ processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
+ output_dict = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ for stage_id in processing_order:
+ # Get the image features for the current frames
+ # img_ids = input.find_inputs[stage_id].img_ids
+ img_ids = input.flat_obj_to_img_idx[stage_id]
+ if img_feats_already_computed:
+ # Retrieve image features according to img_ids (if they are already computed).
+ current_vision_feats = [x[:, img_ids] for x in vision_feats]
+ current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
+ else:
+ # Otherwise, compute the image features on the fly for the given img_ids
+ # (this might be used for evaluation on long videos to avoid backbone OOM).
+ (
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._prepare_backbone_features_per_frame(
+ input.flat_img_batch, img_ids
+ )
+
+ # Get output masks based on this frame's prompts and previous memory
+ current_out = self.track_step(
+ frame_idx=stage_id,
+ is_init_cond_frame=stage_id in init_cond_frames,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
+ mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
+ gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
+ frames_to_add_correction_pt=frames_to_add_correction_pt,
+ output_dict=output_dict,
+ num_frames=num_frames,
+ )
+ # Append the output, depending on whether it's a conditioning frame
+ add_output_as_cond_frame = stage_id in init_cond_frames or (
+ self.add_all_frames_to_correct_as_cond
+ and stage_id in frames_to_add_correction_pt
+ )
+ if add_output_as_cond_frame:
+ output_dict["cond_frame_outputs"][stage_id] = current_out
+ else:
+ output_dict["non_cond_frame_outputs"][stage_id] = current_out
+
+ if return_dict:
+ return output_dict
+ # turn `output_dict` into a list for loss function
+ all_frame_outputs = {}
+ all_frame_outputs.update(output_dict["cond_frame_outputs"])
+ all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
+ all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
+ # Make DDP happy with activation checkpointing by removing unused keys
+ all_frame_outputs = [
+ {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
+ ]
+
+ return all_frame_outputs
+
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks.
+ prev_sam_mask_logits=None, # The previously predicted SAM mask logits.
+ frames_to_add_correction_pt=None,
+ gt_masks=None,
+ ):
+ if frames_to_add_correction_pt is None:
+ frames_to_add_correction_pt = []
+ current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse,
+ prev_sam_mask_logits,
+ )
+
+ (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ ) = sam_outputs
+
+ current_out["multistep_pred_masks"] = low_res_masks
+ current_out["multistep_pred_masks_high_res"] = high_res_masks
+ current_out["multistep_pred_multimasks"] = [low_res_multimasks]
+ current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
+ current_out["multistep_pred_ious"] = [ious]
+ current_out["multistep_point_inputs"] = [point_inputs]
+ current_out["multistep_object_score_logits"] = [object_score_logits]
+
+ # Optionally, sample correction points iteratively to correct the mask
+ if frame_idx in frames_to_add_correction_pt:
+ point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
+ is_init_cond_frame,
+ point_inputs,
+ gt_masks,
+ high_res_features,
+ pix_feat,
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ )
+ (
+ _,
+ _,
+ _,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ ) = final_sam_outputs
+
+ # Use the final prediction (after all correction steps for output and eval)
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (that can be used in future frames)
+ self._encode_memory_in_output(
+ current_vision_feats,
+ feat_sizes,
+ point_inputs,
+ run_mem_encoder,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ )
+ return current_out
+
+ def _iter_correct_pt_sampling(
+ self,
+ is_init_cond_frame,
+ point_inputs,
+ gt_masks,
+ high_res_features,
+ pix_feat_with_mem,
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ ):
+
+ assert gt_masks is not None
+ all_pred_masks = [low_res_masks]
+ all_pred_high_res_masks = [high_res_masks]
+ all_pred_multimasks = [low_res_multimasks]
+ all_pred_high_res_multimasks = [high_res_multimasks]
+ all_pred_ious = [ious]
+ all_point_inputs = [point_inputs]
+ all_object_score_logits = [object_score_logits]
+ for _ in range(self.num_correction_pt_per_frame):
+ # sample a new point from the error between prediction and ground-truth
+ # (with a small probability, directly sample from GT masks instead of errors)
+ if self.training and self.prob_to_sample_from_gt_for_train > 0:
+ sample_from_gt = (
+ self.rng.random() < self.prob_to_sample_from_gt_for_train
+ )
+ else:
+ sample_from_gt = False
+ # if `pred_for_new_pt` is None, only GT masks will be used for point sampling
+ pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
+ new_points, new_labels = get_next_point(
+ gt_masks=gt_masks,
+ pred_masks=pred_for_new_pt,
+ method="uniform" if self.training else self.pt_sampling_for_eval,
+ )
+ point_inputs = concat_points(point_inputs, new_points, new_labels)
+ # Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
+ # For tracking, this means that when the user adds a correction click, we also feed
+ # the tracking output mask logits along with the click as input to the SAM decoder.
+ mask_inputs = low_res_masks
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
+ sam_outputs = torch.utils.checkpoint.checkpoint(
+ self._forward_sam_heads,
+ backbone_features=pix_feat_with_mem,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ use_reentrant=False,
+ )
+ else:
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat_with_mem,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+ (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ _,
+ object_score_logits,
+ ) = sam_outputs
+ all_pred_masks.append(low_res_masks)
+ all_pred_high_res_masks.append(high_res_masks)
+ all_pred_multimasks.append(low_res_multimasks)
+ all_pred_high_res_multimasks.append(high_res_multimasks)
+ all_pred_ious.append(ious)
+ all_point_inputs.append(point_inputs)
+ all_object_score_logits.append(object_score_logits)
+
+ # Concatenate the masks along channel (to compute losses on all of them,
+ # using `MultiStepIteractiveMasks`)
+ current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
+ current_out["multistep_pred_masks_high_res"] = torch.cat(
+ all_pred_high_res_masks, dim=1
+ )
+ current_out["multistep_pred_multimasks"] = all_pred_multimasks
+ current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
+ current_out["multistep_pred_ious"] = all_pred_ious
+ current_out["multistep_point_inputs"] = all_point_inputs
+ current_out["multistep_object_score_logits"] = all_object_score_logits
+
+ return point_inputs, sam_outputs
diff --git a/phantom/submodules/sam2/training/optimizer.py b/phantom/submodules/sam2/training/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae159663f6efc2dac4f5ffa3b1c91b97a78dec76
--- /dev/null
+++ b/phantom/submodules/sam2/training/optimizer.py
@@ -0,0 +1,502 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import fnmatch
+import inspect
+import itertools
+import logging
+import types
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+)
+
+import hydra
+
+import torch
+import torch.nn as nn
+from omegaconf import DictConfig
+from torch import Tensor
+
+
+class Optimizer:
+ def __init__(self, optimizer, schedulers=None) -> None:
+ self.optimizer = optimizer
+ self.schedulers = schedulers
+ self._validate_optimizer_schedulers()
+ self.step_schedulers(0.0, 0)
+
+ def _validate_optimizer_schedulers(self):
+ if self.schedulers is None:
+ return
+ for _, set_of_schedulers in enumerate(self.schedulers):
+ for option, _ in set_of_schedulers.items():
+ assert option in self.optimizer.defaults, (
+ "Optimizer option "
+ f"{option} not found in {self.optimizer}. Valid options are "
+ f"{self.optimizer.defaults.keys()}"
+ )
+
+ def step_schedulers(self, where: float, step: int) -> None:
+ if self.schedulers is None:
+ return
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ for option, scheduler in self.schedulers[i].items():
+ if "step" in inspect.signature(scheduler.__call__).parameters:
+ new_value = scheduler(step=step, where=where)
+ elif (
+ hasattr(scheduler, "scheduler")
+ and "step"
+ in inspect.signature(scheduler.scheduler.__call__).parameters
+ ):
+ # To handle ValueScaler wrappers
+ new_value = scheduler(step=step, where=where)
+ else:
+ new_value = scheduler(where)
+ param_group[option] = new_value
+
+ def step(self, where, step, closure=None):
+ self.step_schedulers(where, step)
+ return self.optimizer.step(closure)
+
+ def zero_grad(self, *args, **kwargs):
+ return self.optimizer.zero_grad(*args, **kwargs)
+
+
+def set_default_parameters(
+ scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
+) -> None:
+ """Set up the "default" scheduler with the right parameters.
+
+ Args:
+ scheduler_cgfs: A list of scheduler configs, where each scheduler also
+ specifies which parameters it applies to, based on the names of parameters
+ or the class of the modules. At most one scheduler is allowed to skip this
+ specification, which is used as a "default" specification for any remaining
+ parameters.
+ all_parameter_names: Names of all the parameters to consider.
+ """
+ constraints = [
+ scheduler_cfg.parameter_names
+ for scheduler_cfg in scheduler_cfgs
+ if scheduler_cfg.parameter_names is not None
+ ]
+ if len(constraints) == 0:
+ default_params = set(all_parameter_names)
+ else:
+ default_params = all_parameter_names - set.union(*constraints)
+ default_count = 0
+ for scheduler_cfg in scheduler_cfgs:
+ if scheduler_cfg.parameter_names is None:
+ scheduler_cfg.parameter_names = default_params
+ default_count += 1
+ assert default_count <= 1, "Only one scheduler per option can be default"
+ if default_count == 0:
+ # No default scheduler specified, add a default, but without any scheduler
+ # for that option
+ scheduler_cfgs.append({"parameter_names": default_params})
+
+
+def name_constraints_to_parameters(
+ param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
+) -> List[torch.nn.Parameter]:
+ """Return parameters which match the intersection of parameter constraints.
+
+ Note that this returns the parameters themselves, not their names.
+
+ Args:
+ param_constraints: A list, with each element being a set of allowed parameters.
+ named_parameters: Mapping from a parameter name to the parameter itself.
+
+ Returns:
+ A list containing the parameters which overlap with _each_ constraint set from
+ param_constraints.
+ """
+ matching_names = set.intersection(*param_constraints)
+ return [value for name, value in named_parameters.items() if name in matching_names]
+
+
+def map_scheduler_cfgs_to_param_groups(
+ all_scheduler_cfgs: Iterable[List[Dict]],
+ named_parameters: Dict[str, Tensor],
+) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
+ """Produce parameter groups corresponding to all the scheduler configs.
+
+ Takes all the scheduler configs, each of which applies to a specific optimizer
+ option (like "lr" or "weight_decay") and has a set of parameter names which it
+ applies to, and produces a final set of param groups where each param group
+ covers all the options which apply to a particular set of parameters.
+
+ Args:
+ all_scheduler_cfgs: All the scheduler configs covering every option.
+ named_parameters: Mapping from a parameter name to the parameter itself.
+ Returns:
+ Tuple of lists of schedulers and param_groups, where schedulers[i]
+ applies to param_groups[i].
+ """
+
+ scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
+ schedulers = []
+ param_groups = []
+ for scheduler_cfgs in scheduler_cfgs_per_param_group:
+ param_constraints = [
+ scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
+ ]
+ matching_parameters = name_constraints_to_parameters(
+ param_constraints, named_parameters
+ )
+ if len(matching_parameters) == 0: # If no overlap of parameters, skip
+ continue
+ schedulers_for_group = {
+ scheduler_cfg["option"]: scheduler_cfg["scheduler"]
+ for scheduler_cfg in scheduler_cfgs
+ if "option" in scheduler_cfg
+ }
+ schedulers.append(schedulers_for_group)
+ param_groups.append({"params": matching_parameters})
+ return schedulers, param_groups
+
+
+def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
+ """Check that the param groups are non-overlapping and cover all the parameters.
+
+ Args:
+ param_groups: List of all param groups
+ model: Model to validate against. The check ensures that all the model
+ parameters are part of param_groups
+ """
+ for pg in param_groups:
+ # no param should be repeated within a group
+ assert len(pg["params"]) == len(set(pg["params"]))
+ parameters = [set(param_group["params"]) for param_group in param_groups]
+ model_parameters = {parameter for _, parameter in model.named_parameters()}
+ for p1, p2 in itertools.permutations(parameters, 2):
+ assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
+ assert set.union(*parameters) == model_parameters, (
+ "Scheduler generated param_groups must include all parameters of the model."
+ f" Found {len(set.union(*parameters))} params whereas model has"
+ f" {len(model_parameters)} params"
+ )
+
+
+def unix_module_cls_pattern_to_parameter_names(
+ filter_module_cls_names: List[str],
+ module_cls_to_param_names: Dict[Type, str],
+) -> Union[None, Set[str]]:
+ """Returns param names which pass the filters specified in filter_module_cls_names.
+
+ Args:
+ filter_module_cls_names: A list of filter strings containing class names, like
+ ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
+ module_cls_to_param_names: Mapping from module classes to the parameter names
+ they contain. See `get_module_cls_to_param_names`.
+ """
+ if filter_module_cls_names is None:
+ return set()
+ allowed_parameter_names = []
+ for module_cls_name in filter_module_cls_names:
+ module_cls = hydra.utils.get_class(module_cls_name)
+ if module_cls not in module_cls_to_param_names:
+ raise AssertionError(
+ f"module_cls_name {module_cls_name} does not "
+ "match any classes in the model"
+ )
+ matching_parameters = module_cls_to_param_names[module_cls]
+ assert (
+ len(matching_parameters) > 0
+ ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
+ logging.info(
+ f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
+ )
+ allowed_parameter_names.append(matching_parameters)
+ return set.union(*allowed_parameter_names)
+
+
+def unix_param_pattern_to_parameter_names(
+ filter_param_names: Optional[List[str]],
+ parameter_names: Dict[str, torch.Tensor],
+) -> Union[None, Set[str]]:
+ """Returns param names which pass the filters specified in filter_param_names.
+
+ Args:
+ filter_param_names: A list of unix-style filter strings with optional
+ wildcards, like ["block.2.*", "block.2.linear.weight"]
+ module_cls_to_param_names: Mapping from module classes to the parameter names
+ they contain. See `get_module_cls_to_param_names`.
+ """
+
+ if filter_param_names is None:
+ return set()
+ allowed_parameter_names = []
+ for param_name in filter_param_names:
+ matching_parameters = set(fnmatch.filter(parameter_names, param_name))
+ assert (
+ len(matching_parameters) >= 1
+ ), f"param_name {param_name} does not match any parameters in the model"
+ logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
+ allowed_parameter_names.append(matching_parameters)
+ return set.union(*allowed_parameter_names)
+
+
+def _unix_pattern_to_parameter_names(
+ scheduler_cfg: DictConfig,
+ parameter_names: Set[str],
+ module_cls_to_param_names: Dict[Type, str],
+) -> Union[None, Set[str]]:
+ """Returns param names which pass the filters specified in scheduler_cfg.
+
+ Args:
+ scheduler_cfg: The config for the scheduler
+ parameter_names: The set of all parameter names which will be filtered
+ """
+ if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
+ return None
+ return unix_param_pattern_to_parameter_names(
+ scheduler_cfg.get("param_names"), parameter_names
+ ).union(
+ unix_module_cls_pattern_to_parameter_names(
+ scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
+ )
+ )
+
+
+def get_module_cls_to_param_names(
+ model: nn.Module, param_allowlist: Set[str] = None
+) -> Dict[Type, str]:
+ """Produce a mapping from all the modules classes to the names of parames they own.
+
+ Only counts a parameter as part of the immediate parent module, i.e. recursive
+ parents do not count.
+
+ Args:
+ model: Model to iterate over
+ param_allowlist: If specified, only these param names will be processed
+ """
+
+ module_cls_to_params = {}
+ for module_name, module in model.named_modules():
+ module_cls = type(module)
+ module_cls_to_params.setdefault(module_cls, set())
+ for param_name, _ in module.named_parameters(recurse=False):
+ full_param_name = get_full_parameter_name(module_name, param_name)
+ if param_allowlist is None or full_param_name in param_allowlist:
+ module_cls_to_params[module_cls].add(full_param_name)
+ return module_cls_to_params
+
+
+def construct_optimizer(
+ model: torch.nn.Module,
+ optimizer_conf: Any,
+ options_conf: Mapping[str, List] = None,
+ param_group_modifiers_conf: List[Callable] = None,
+ param_allowlist: Optional[Set[str]] = None,
+ validate_param_groups=True,
+) -> Optimizer:
+ """
+ Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer
+ with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay
+ Batchnorm and/or no-update 1-D parameters support, based on the config.
+
+ Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
+ (LARS): https://arxiv.org/abs/1708.03888
+
+ Args:
+ model: model to perform stochastic gradient descent
+ optimization or ADAM optimization.
+ optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
+ ADAM, still missing the params argument which this function provides to
+ produce the final optimizer
+ param_group_modifiers_conf: Optional user specified functions which can modify
+ the final scheduler configs before the optimizer's param groups are built
+ param_allowlist: The parameters to optimize. Parameters which are not part of
+ this allowlist will be skipped.
+ validate_param_groups: If enabled, valides that the produced param_groups don't
+ overlap and cover all the model parameters.
+ """
+ if param_allowlist is None:
+ param_allowlist = {name for name, _ in model.named_parameters()}
+
+ named_parameters = {
+ name: param
+ for name, param in model.named_parameters()
+ if name in param_allowlist
+ }
+
+ if not options_conf:
+ optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
+ return Optimizer(optimizer)
+
+ all_parameter_names = {
+ name for name, _ in model.named_parameters() if name in param_allowlist
+ }
+ module_cls_to_all_param_names = get_module_cls_to_param_names(
+ model, param_allowlist
+ )
+
+ scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
+ all_scheduler_cfgs = []
+ for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
+ for config in scheduler_cfgs:
+ config.option = option
+ config.parameter_names = _unix_pattern_to_parameter_names(
+ config, all_parameter_names, module_cls_to_all_param_names
+ )
+ set_default_parameters(scheduler_cfgs, all_parameter_names)
+ all_scheduler_cfgs.append(scheduler_cfgs)
+
+ if param_group_modifiers_conf:
+ for custom_param_modifier in param_group_modifiers_conf:
+ custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
+ all_scheduler_cfgs = custom_param_modifier(
+ scheduler_cfgs=all_scheduler_cfgs, model=model
+ )
+ schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
+ all_scheduler_cfgs, named_parameters
+ )
+ if validate_param_groups:
+ validate_param_group_params(param_groups, model)
+ optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
+ return Optimizer(optimizer, schedulers)
+
+
+def get_full_parameter_name(module_name, param_name):
+ if module_name == "":
+ return param_name
+ return f"{module_name}.{param_name}"
+
+
+class GradientClipper:
+ """
+ Gradient clipping utils that works for DDP
+ """
+
+ def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
+ assert isinstance(max_norm, (int, float)) or max_norm is None
+ self.max_norm = max_norm if max_norm is None else float(max_norm)
+ self.norm_type = norm_type
+
+ def __call__(self, model: nn.Module):
+ if self.max_norm is None:
+ return # no-op
+
+ nn.utils.clip_grad_norm_(
+ model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
+ )
+
+
+class ValueScaler:
+ def __init__(self, scheduler, mult_val: float):
+ self.scheduler = scheduler
+ self.mult_val = mult_val
+
+ def __call__(self, *args, **kwargs):
+ val = self.scheduler(*args, **kwargs)
+ return val * self.mult_val
+
+
+def rgetattr(obj, rattrs: str = None):
+ """
+ Like getattr(), but supports dotted notation for nested objects.
+ rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
+ """
+ if rattrs is None:
+ return obj
+ attrs = rattrs.split(".")
+ for attr in attrs:
+ obj = getattr(obj, attr)
+ return obj
+
+
+def layer_decay_param_modifier(
+ scheduler_cfgs: List[List[Dict]],
+ model,
+ layer_decay_value: float,
+ layer_decay_min: Optional[float] = None,
+ apply_to: Optional[str] = None,
+ overrides: List[Dict] = (),
+) -> List[List[Dict]]:
+ """
+ Args
+ - scheduler_cfgs: a list of omegaconf.ListConfigs.
+ Each element in the list is a omegaconfg.DictConfig with the following structure
+ {
+ "scheduler":
+ "option": possible options are "lr", "weight_decay" etc.
+ "parameter_names": Set of str indicating param names that this scheduler applies to
+ }
+ - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and
+ and a method get_num_layers.
+ Alternatively, use apply_to argument to select a specific component of the model.
+ - layer_decay_value: float
+ - layer_decay_min: min val for layer decay
+ - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to
+ - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value".
+ Returns
+ - scheduler_configs: same structure as the input, elements can be modified
+ """
+ model = rgetattr(model, apply_to)
+ num_layers = model.get_num_layers() + 1
+ layer_decays = [
+ layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
+ ]
+ if layer_decay_min is not None:
+ layer_decays = [max(val, layer_decay_min) for val in layer_decays]
+ final_scheduler_cfgs = []
+ # scheduler_cfgs is a list of lists
+ for scheduler_cfg_group in scheduler_cfgs:
+ curr_cfg_group = []
+ # scheduler_cfg_group is a list of dictionaries
+ for scheduler_cfg in scheduler_cfg_group:
+ if scheduler_cfg["option"] != "lr":
+ curr_cfg_group.append(scheduler_cfg)
+ continue
+ # Need sorted so that the list of parameter names is deterministic and consistent
+ # across re-runs of this job. Else it was causing issues with loading the optimizer
+ # state during a job restart (D38591759)
+ parameter_names = sorted(scheduler_cfg["parameter_names"])
+
+ # Only want one cfg group per layer
+ layer_cfg_groups = {}
+ for param_name in parameter_names:
+ layer_id = num_layers
+ this_scale = layer_decays[layer_id]
+ if param_name.startswith(apply_to):
+ layer_id = model.get_layer_id(param_name)
+ this_scale = layer_decays[layer_id]
+ # Overrides
+ for override in overrides:
+ if fnmatch.fnmatchcase(param_name, override["pattern"]):
+ this_scale = float(override["value"])
+ layer_id = override["pattern"]
+ break
+
+ if layer_id not in layer_cfg_groups:
+ curr_param = {
+ "option": scheduler_cfg["option"],
+ "scheduler": ValueScaler(
+ scheduler_cfg["scheduler"], this_scale
+ ),
+ "parameter_names": {param_name},
+ }
+ else:
+ curr_param = layer_cfg_groups[layer_id]
+ curr_param["parameter_names"].add(param_name)
+ layer_cfg_groups[layer_id] = curr_param
+
+ for layer_cfg in layer_cfg_groups.values():
+ curr_cfg_group.append(layer_cfg)
+
+ final_scheduler_cfgs.append(curr_cfg_group)
+ return final_scheduler_cfgs
diff --git a/phantom/submodules/sam2/training/scripts/sav_frame_extraction_submitit.py b/phantom/submodules/sam2/training/scripts/sav_frame_extraction_submitit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d5ed2fc77deecf87c8d823bb3fdcf3cb856fc94
--- /dev/null
+++ b/phantom/submodules/sam2/training/scripts/sav_frame_extraction_submitit.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+import argparse
+import os
+from pathlib import Path
+
+import cv2
+
+import numpy as np
+import submitit
+import tqdm
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser(
+ description="[SA-V Preprocessing] Extracting JPEG frames",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # ------------
+ # DATA
+ # ------------
+ data_parser = parser.add_argument_group(
+ title="SA-V dataset data root",
+ description="What data to load and how to process it.",
+ )
+ data_parser.add_argument(
+ "--sav-vid-dir",
+ type=str,
+ required=True,
+ help=("Where to find the SAV videos"),
+ )
+ data_parser.add_argument(
+ "--sav-frame-sample-rate",
+ type=int,
+ default=4,
+ help="Rate at which to sub-sample frames",
+ )
+
+ # ------------
+ # LAUNCH
+ # ------------
+ launch_parser = parser.add_argument_group(
+ title="Cluster launch settings",
+ description="Number of jobs and retry settings.",
+ )
+ launch_parser.add_argument(
+ "--n-jobs",
+ type=int,
+ required=True,
+ help="Shard the run over this many jobs.",
+ )
+ launch_parser.add_argument(
+ "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes."
+ )
+ launch_parser.add_argument(
+ "--partition", type=str, required=True, help="Partition to launch on."
+ )
+ launch_parser.add_argument(
+ "--account", type=str, required=True, help="Partition to launch on."
+ )
+ launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")
+
+ # ------------
+ # OUTPUT
+ # ------------
+ output_parser = parser.add_argument_group(
+ title="Setting for results output", description="Where and how to save results."
+ )
+ output_parser.add_argument(
+ "--output-dir",
+ type=str,
+ required=True,
+ help=("Where to dump the extracted jpeg frames"),
+ )
+ output_parser.add_argument(
+ "--slurm-output-root-dir",
+ type=str,
+ required=True,
+ help=("Where to save slurm outputs"),
+ )
+ return parser
+
+
+def decode_video(video_path: str):
+ assert os.path.exists(video_path)
+ video = cv2.VideoCapture(video_path)
+ video_frames = []
+ while video.isOpened():
+ ret, frame = video.read()
+ if ret:
+ video_frames.append(frame)
+ else:
+ break
+ return video_frames
+
+
+def extract_frames(video_path, sample_rate):
+ frames = decode_video(video_path)
+ return frames[::sample_rate]
+
+
+def submitit_launch(video_paths, sample_rate, save_root):
+ for path in tqdm.tqdm(video_paths):
+ frames = extract_frames(path, sample_rate)
+ output_folder = os.path.join(save_root, Path(path).stem)
+ if not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+ for fid, frame in enumerate(frames):
+ frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
+ cv2.imwrite(frame_path, frame)
+ print(f"Saved output to {save_root}")
+
+
+if __name__ == "__main__":
+ parser = get_args_parser()
+ args = parser.parse_args()
+
+ sav_vid_dir = args.sav_vid_dir
+ save_root = args.output_dir
+ sample_rate = args.sav_frame_sample_rate
+
+ # List all SA-V videos
+ mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")])
+ mp4_files = np.array(mp4_files)
+ chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)]
+
+ print(f"Processing videos in: {sav_vid_dir}")
+ print(f"Processing {len(mp4_files)} files")
+ print(f"Beginning processing in {args.n_jobs} processes")
+
+ # Submitit params
+ jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
+ cpus_per_task = 4
+ executor = submitit.AutoExecutor(folder=jobs_dir)
+ executor.update_parameters(
+ timeout_min=args.timeout,
+ gpus_per_node=0,
+ tasks_per_node=1,
+ slurm_array_parallelism=args.n_jobs,
+ cpus_per_task=cpus_per_task,
+ slurm_partition=args.partition,
+ slurm_account=args.account,
+ slurm_qos=args.qos,
+ )
+ executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])
+
+ # Launch
+ jobs = []
+ with executor.batch():
+ for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)):
+ job = executor.submit(
+ submitit_launch,
+ video_paths=mp4_chunk,
+ sample_rate=sample_rate,
+ save_root=save_root,
+ )
+ jobs.append(job)
+
+ for j in jobs:
+ print(f"Slurm JobID: {j.job_id}")
+ print(f"Saving outputs to {save_root}")
+ print(f"Slurm outputs at {args.slurm_output_root_dir}")
diff --git a/phantom/submodules/sam2/training/train.py b/phantom/submodules/sam2/training/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..db06123fcb1b2ba8ff5f462dbb7411d42a57c9a0
--- /dev/null
+++ b/phantom/submodules/sam2/training/train.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import sys
+import traceback
+from argparse import ArgumentParser
+
+import submitit
+import torch
+
+from hydra import compose, initialize_config_module
+from hydra.utils import instantiate
+
+from iopath.common.file_io import g_pathmgr
+from omegaconf import OmegaConf
+
+from training.utils.train_utils import makedir, register_omegaconf_resolvers
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+
+
+def single_proc_run(local_rank, main_port, cfg, world_size):
+ """Single GPU process"""
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(main_port)
+ os.environ["RANK"] = str(local_rank)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ try:
+ register_omegaconf_resolvers()
+ except Exception as e:
+ logging.info(e)
+
+ trainer = instantiate(cfg.trainer, _recursive_=False)
+ trainer.run()
+
+
+def single_node_runner(cfg, main_port: int):
+ assert cfg.launcher.num_nodes == 1
+ num_proc = cfg.launcher.gpus_per_node
+ torch.multiprocessing.set_start_method(
+ "spawn"
+ ) # CUDA runtime does not support `fork`
+ if num_proc == 1:
+ # directly call single_proc so we can easily set breakpoints
+ # mp.spawn does not let us set breakpoints
+ single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
+ else:
+ mp_runner = torch.multiprocessing.start_processes
+ args = (main_port, cfg, num_proc)
+ # Note: using "fork" below, "spawn" causes time and error regressions. Using
+ # spawn changes the default multiprocessing context to spawn, which doesn't
+ # interact well with the dataloaders (likely due to the use of OpenCV).
+ mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
+
+
+def format_exception(e: Exception, limit=20):
+ traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit))
+ return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}"
+
+
+class SubmititRunner(submitit.helpers.Checkpointable):
+ """A callable which is passed to submitit to launch the jobs."""
+
+ def __init__(self, port, cfg):
+ self.cfg = cfg
+ self.port = port
+ self.has_setup = False
+
+ def run_trainer(self):
+ job_env = submitit.JobEnvironment()
+ # Need to add this again so the hydra.job.set_env PYTHONPATH
+ # is also set when launching jobs.
+ add_pythonpath_to_sys_path()
+ os.environ["MASTER_ADDR"] = job_env.hostnames[0]
+ os.environ["MASTER_PORT"] = str(self.port)
+ os.environ["RANK"] = str(job_env.global_rank)
+ os.environ["LOCAL_RANK"] = str(job_env.local_rank)
+ os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
+
+ register_omegaconf_resolvers()
+ cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
+ cfg_resolved = OmegaConf.create(cfg_resolved)
+
+ trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
+ trainer.run()
+
+ def __call__(self):
+ job_env = submitit.JobEnvironment()
+ self.setup_job_info(job_env.job_id, job_env.global_rank)
+ try:
+ self.run_trainer()
+ except Exception as e:
+ # Log the exception. Then raise it again (as what SubmititRunner currently does).
+ message = format_exception(e)
+ logging.error(message)
+ raise e
+
+ def setup_job_info(self, job_id, rank):
+ """Set up slurm job info"""
+ self.job_info = {
+ "job_id": job_id,
+ "rank": rank,
+ "cluster": self.cfg.get("cluster", None),
+ "experiment_log_dir": self.cfg.launcher.experiment_log_dir,
+ }
+
+ self.has_setup = True
+
+
+def add_pythonpath_to_sys_path():
+ if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]:
+ return
+ sys.path = os.environ["PYTHONPATH"].split(":") + sys.path
+
+
+def main(args) -> None:
+ cfg = compose(config_name=args.config)
+ if cfg.launcher.experiment_log_dir is None:
+ cfg.launcher.experiment_log_dir = os.path.join(
+ os.getcwd(), "sam2_logs", args.config
+ )
+ print("###################### Train App Config ####################")
+ print(OmegaConf.to_yaml(cfg))
+ print("############################################################")
+
+ add_pythonpath_to_sys_path()
+ makedir(cfg.launcher.experiment_log_dir)
+ with g_pathmgr.open(
+ os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
+ ) as f:
+ f.write(OmegaConf.to_yaml(cfg))
+
+ cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
+ cfg_resolved = OmegaConf.create(cfg_resolved)
+
+ with g_pathmgr.open(
+ os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
+ ) as f:
+ f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))
+
+ submitit_conf = cfg.get("submitit", None)
+ assert submitit_conf is not None, "Missing submitit config"
+
+ submitit_dir = cfg.launcher.experiment_log_dir
+ submitit_dir = os.path.join(submitit_dir, "submitit_logs")
+ # Priotrize cmd line args
+ cfg.launcher.gpus_per_node = (
+ args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
+ )
+ cfg.launcher.num_nodes = (
+ args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
+ )
+ submitit_conf.use_cluster = (
+ args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
+ )
+ if submitit_conf.use_cluster:
+ executor = submitit.AutoExecutor(folder=submitit_dir)
+ submitit_conf.partition = (
+ args.partition
+ if args.partition is not None
+ else submitit_conf.get("partition", None)
+ )
+ submitit_conf.account = (
+ args.account
+ if args.account is not None
+ else submitit_conf.get("account", None)
+ )
+ submitit_conf.qos = (
+ args.qos if args.qos is not None else submitit_conf.get("qos", None)
+ )
+ job_kwargs = {
+ "timeout_min": 60 * submitit_conf.timeout_hour,
+ "name": (
+ submitit_conf.name if hasattr(submitit_conf, "name") else args.config
+ ),
+ "slurm_partition": submitit_conf.partition,
+ "gpus_per_node": cfg.launcher.gpus_per_node,
+ "tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU
+ "cpus_per_task": submitit_conf.cpus_per_task,
+ "nodes": cfg.launcher.num_nodes,
+ "slurm_additional_parameters": {
+ "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
+ },
+ }
+ if "include_nodes" in submitit_conf:
+ assert (
+ len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
+ ), "Not enough nodes"
+ job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
+ submitit_conf["include_nodes"]
+ )
+ if submitit_conf.account is not None:
+ job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
+ if submitit_conf.qos is not None:
+ job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
+
+ if submitit_conf.get("mem_gb", None) is not None:
+ job_kwargs["mem_gb"] = submitit_conf.mem_gb
+ elif submitit_conf.get("mem", None) is not None:
+ job_kwargs["slurm_mem"] = submitit_conf.mem
+
+ if submitit_conf.get("constraints", None) is not None:
+ job_kwargs["slurm_constraint"] = submitit_conf.constraints
+
+ if submitit_conf.get("comment", None) is not None:
+ job_kwargs["slurm_comment"] = submitit_conf.comment
+
+ # Supports only cpu-bind option within srun_args. New options can be added here
+ if submitit_conf.get("srun_args", None) is not None:
+ job_kwargs["slurm_srun_args"] = []
+ if submitit_conf.srun_args.get("cpu_bind", None) is not None:
+ job_kwargs["slurm_srun_args"].extend(
+ ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
+ )
+
+ print("###################### SLURM Config ####################")
+ print(job_kwargs)
+ print("##########################################")
+ executor.update_parameters(**job_kwargs)
+
+ main_port = random.randint(
+ submitit_conf.port_range[0], submitit_conf.port_range[1]
+ )
+ runner = SubmititRunner(main_port, cfg)
+ job = executor.submit(runner)
+ print(f"Submitit Job ID: {job.job_id}")
+ runner.setup_job_info(job.job_id, rank=0)
+ else:
+ cfg.launcher.num_nodes = 1
+ main_port = random.randint(
+ submitit_conf.port_range[0], submitit_conf.port_range[1]
+ )
+ single_node_runner(cfg, main_port)
+
+
+if __name__ == "__main__":
+
+ initialize_config_module("sam2", version_base="1.2")
+ parser = ArgumentParser()
+ parser.add_argument(
+ "-c",
+ "--config",
+ required=True,
+ type=str,
+ help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
+ )
+ parser.add_argument(
+ "--use-cluster",
+ type=int,
+ default=None,
+ help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
+ )
+ parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
+ parser.add_argument("--account", type=str, default=None, help="SLURM account")
+ parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
+ parser.add_argument(
+ "--num-gpus", type=int, default=None, help="number of GPUS per node"
+ )
+ parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
+ args = parser.parse_args()
+ args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
+ register_omegaconf_resolvers()
+ main(args)
diff --git a/phantom/submodules/sam2/training/trainer.py b/phantom/submodules/sam2/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b7c27b5145e2c03848331345ac246296accbc1d
--- /dev/null
+++ b/phantom/submodules/sam2/training/trainer.py
@@ -0,0 +1,1113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import gc
+import json
+import logging
+import math
+import os
+import time
+from collections import OrderedDict
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Mapping, Optional
+
+import numpy as np
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from hydra.utils import instantiate
+from iopath.common.file_io import g_pathmgr
+
+from training.optimizer import construct_optimizer
+
+from training.utils.checkpoint_utils import (
+ assert_skipped_parameters_are_frozen,
+ exclude_params_matching_unix_pattern,
+ load_state_dict_into_model,
+ with_check_parameter_frozen,
+)
+from training.utils.data_utils import BatchedVideoDatapoint
+from training.utils.distributed import all_reduce_max, barrier, get_rank
+
+from training.utils.logger import Logger, setup_logging
+
+from training.utils.train_utils import (
+ AverageMeter,
+ collect_dict_keys,
+ DurationMeter,
+ get_amp_type,
+ get_machine_local_and_dist_rank,
+ get_resume_checkpoint,
+ human_readable_time,
+ is_dist_avail_and_initialized,
+ log_env_variables,
+ makedir,
+ MemMeter,
+ Phase,
+ ProgressMeter,
+ set_seeds,
+ setup_distributed_backend,
+)
+
+
+CORE_LOSS_KEY = "core_loss"
+
+
+def unwrap_ddp_if_wrapped(model):
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ return model.module
+ return model
+
+
+@dataclass
+class OptimAMPConf:
+ enabled: bool = False
+ amp_dtype: str = "float16"
+
+
+@dataclass
+class OptimConf:
+ optimizer: torch.optim.Optimizer = None
+ options: Optional[Dict[str, Any]] = None
+ param_group_modifiers: Optional[List] = None
+ amp: Optional[Dict[str, Any]] = None
+ gradient_clip: Any = None
+ gradient_logger: Any = None
+
+ def __post_init__(self):
+ # amp
+ if not isinstance(self.amp, OptimAMPConf):
+ if self.amp is None:
+ self.amp = {}
+ assert isinstance(self.amp, Mapping)
+ self.amp = OptimAMPConf(**self.amp)
+
+
+@dataclass
+class DistributedConf:
+ backend: Optional[str] = None # inferred from accelerator type
+ comms_dtype: Optional[str] = None
+ find_unused_parameters: bool = False
+ timeout_mins: int = 30
+
+
+@dataclass
+class CudaConf:
+ cudnn_deterministic: bool = False
+ cudnn_benchmark: bool = True
+ allow_tf32: bool = False
+ # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
+ matmul_allow_tf32: Optional[bool] = None
+ # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
+ cudnn_allow_tf32: Optional[bool] = None
+
+
+@dataclass
+class CheckpointConf:
+ save_dir: str
+ save_freq: int
+ save_list: List[int] = field(default_factory=list)
+ model_weight_initializer: Any = None
+ save_best_meters: List[str] = None
+ skip_saving_parameters: List[str] = field(default_factory=list)
+ initialize_after_preemption: Optional[bool] = None
+ # if not None, training will be resumed from this checkpoint
+ resume_from: Optional[str] = None
+
+ def infer_missing(self):
+ if self.initialize_after_preemption is None:
+ with_skip_saving = len(self.skip_saving_parameters) > 0
+ self.initialize_after_preemption = with_skip_saving
+ return self
+
+
+@dataclass
+class LoggingConf:
+ log_dir: str
+ log_freq: int # In iterations
+ tensorboard_writer: Any
+ log_level_primary: str = "INFO"
+ log_level_secondary: str = "ERROR"
+ log_scalar_frequency: int = 100
+ log_visual_frequency: int = 100
+ scalar_keys_to_log: Optional[Dict[str, Any]] = None
+ log_batch_stats: bool = False
+
+
+class Trainer:
+ """
+ Trainer supporting the DDP training strategies.
+ """
+
+ EPSILON = 1e-8
+
+ def __init__(
+ self,
+ *, # the order of these args can change at any time, so they are keyword-only
+ data: Dict[str, Any],
+ model: Dict[str, Any],
+ logging: Dict[str, Any],
+ checkpoint: Dict[str, Any],
+ max_epochs: int,
+ mode: str = "train",
+ accelerator: str = "cuda",
+ seed_value: int = 123,
+ val_epoch_freq: int = 1,
+ distributed: Dict[str, bool] = None,
+ cuda: Dict[str, bool] = None,
+ env_variables: Optional[Dict[str, Any]] = None,
+ optim: Optional[Dict[str, Any]] = None,
+ optim_overrides: Optional[List[Dict[str, Any]]] = None,
+ meters: Optional[Dict[str, Any]] = None,
+ loss: Optional[Dict[str, Any]] = None,
+ ):
+
+ self._setup_env_variables(env_variables)
+ self._setup_timers()
+
+ self.data_conf = data
+ self.model_conf = model
+ self.logging_conf = LoggingConf(**logging)
+ self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
+ self.max_epochs = max_epochs
+ self.mode = mode
+ self.val_epoch_freq = val_epoch_freq
+ self.optim_conf = OptimConf(**optim) if optim is not None else None
+ self.meters_conf = meters
+ self.loss_conf = loss
+ distributed = DistributedConf(**distributed or {})
+ cuda = CudaConf(**cuda or {})
+ self.where = 0.0
+
+ self._infer_distributed_backend_if_none(distributed, accelerator)
+
+ self._setup_device(accelerator)
+
+ self._setup_torch_dist_and_backend(cuda, distributed)
+
+ makedir(self.logging_conf.log_dir)
+ setup_logging(
+ __name__,
+ output_dir=self.logging_conf.log_dir,
+ rank=self.rank,
+ log_level_primary=self.logging_conf.log_level_primary,
+ log_level_secondary=self.logging_conf.log_level_secondary,
+ )
+
+ set_seeds(seed_value, self.max_epochs, self.distributed_rank)
+ log_env_variables()
+
+ assert (
+ is_dist_avail_and_initialized()
+ ), "Torch distributed needs to be initialized before calling the trainer."
+
+ self._setup_components() # Except Optimizer everything is setup here.
+ self._move_to_device()
+ self._construct_optimizers()
+ self._setup_dataloaders()
+
+ self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
+
+ if self.checkpoint_conf.resume_from is not None:
+ assert os.path.exists(
+ self.checkpoint_conf.resume_from
+ ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
+ dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
+ if self.distributed_rank == 0 and not os.path.exists(dst):
+ # Copy the "resume_from" checkpoint to the checkpoint folder
+ # if there is not a checkpoint to resume from already there
+ makedir(self.checkpoint_conf.save_dir)
+ g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
+ barrier()
+
+ self.load_checkpoint()
+ self._setup_ddp_distributed_training(distributed, accelerator)
+ barrier()
+
+ def _setup_timers(self):
+ """
+ Initializes counters for elapsed time and eta.
+ """
+ self.start_time = time.time()
+ self.ckpt_time_elapsed = 0
+ self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)
+
+ def _get_meters(self, phase_filters=None):
+ if self.meters is None:
+ return {}
+ meters = {}
+ for phase, phase_meters in self.meters.items():
+ if phase_filters is not None and phase not in phase_filters:
+ continue
+ for key, key_meters in phase_meters.items():
+ if key_meters is None:
+ continue
+ for name, meter in key_meters.items():
+ meters[f"{phase}_{key}/{name}"] = meter
+ return meters
+
+ def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
+ if distributed_conf.backend is None:
+ distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"
+
+ def _setup_env_variables(self, env_variables_conf) -> None:
+ if env_variables_conf is not None:
+ for variable_name, value in env_variables_conf.items():
+ os.environ[variable_name] = value
+
+ def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
+ if torch.cuda.is_available():
+ torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
+ torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
+ torch.backends.cuda.matmul.allow_tf32 = (
+ cuda_conf.matmul_allow_tf32
+ if cuda_conf.matmul_allow_tf32 is not None
+ else cuda_conf.allow_tf32
+ )
+ torch.backends.cudnn.allow_tf32 = (
+ cuda_conf.cudnn_allow_tf32
+ if cuda_conf.cudnn_allow_tf32 is not None
+ else cuda_conf.allow_tf32
+ )
+
+ self.rank = setup_distributed_backend(
+ distributed_conf.backend, distributed_conf.timeout_mins
+ )
+
+ def _setup_device(self, accelerator):
+ self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
+ if accelerator == "cuda":
+ self.device = torch.device("cuda", self.local_rank)
+ torch.cuda.set_device(self.local_rank)
+ elif accelerator == "cpu":
+ self.device = torch.device("cpu")
+ else:
+ raise ValueError(f"Unsupported accelerator: {accelerator}")
+
+ def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
+
+ assert isinstance(self.model, torch.nn.Module)
+
+ self.model = nn.parallel.DistributedDataParallel(
+ self.model,
+ device_ids=[self.local_rank] if accelerator == "cuda" else [],
+ find_unused_parameters=distributed_conf.find_unused_parameters,
+ )
+ if distributed_conf.comms_dtype is not None: # noqa
+ from torch.distributed.algorithms import ddp_comm_hooks
+
+ amp_type = get_amp_type(distributed_conf.comms_dtype)
+ if amp_type == torch.bfloat16:
+ hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
+ logging.info("Enabling bfloat16 grad communication")
+ else:
+ hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
+ logging.info("Enabling fp16 grad communication")
+ process_group = None
+ self.model.register_comm_hook(process_group, hook)
+
+ def _move_to_device(self):
+ logging.info(
+ f"Moving components to device {self.device} and local rank {self.local_rank}."
+ )
+
+ self.model.to(self.device)
+
+ logging.info(
+ f"Done moving components to device {self.device} and local rank {self.local_rank}."
+ )
+
+ def save_checkpoint(self, epoch, checkpoint_names=None):
+ checkpoint_folder = self.checkpoint_conf.save_dir
+ makedir(checkpoint_folder)
+ if checkpoint_names is None:
+ checkpoint_names = ["checkpoint"]
+ if (
+ self.checkpoint_conf.save_freq > 0
+ and (int(epoch) % self.checkpoint_conf.save_freq == 0)
+ ) or int(epoch) in self.checkpoint_conf.save_list:
+ checkpoint_names.append(f"checkpoint_{int(epoch)}")
+
+ checkpoint_paths = []
+ for ckpt_name in checkpoint_names:
+ checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))
+
+ state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
+ state_dict = exclude_params_matching_unix_pattern(
+ patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
+ )
+
+ checkpoint = {
+ "model": state_dict,
+ "optimizer": self.optim.optimizer.state_dict(),
+ "epoch": epoch,
+ "loss": self.loss.state_dict(),
+ "steps": self.steps,
+ "time_elapsed": self.time_elapsed_meter.val,
+ "best_meter_values": self.best_meter_values,
+ }
+ if self.optim_conf.amp.enabled:
+ checkpoint["scaler"] = self.scaler.state_dict()
+
+ # DDP checkpoints are only saved on rank 0 (all workers are identical)
+ if self.distributed_rank != 0:
+ return
+
+ for checkpoint_path in checkpoint_paths:
+ self._save_checkpoint(checkpoint, checkpoint_path)
+
+ def _save_checkpoint(self, checkpoint, checkpoint_path):
+ """
+ Save a checkpoint while guarding against the job being killed in the middle
+ of checkpoint saving (which corrupts the checkpoint file and ruins the
+ entire training since usually only the last checkpoint is kept per run).
+
+ We first save the new checkpoint to a temp file (with a '.tmp' suffix), and
+ and move it to overwrite the old checkpoint_path.
+ """
+ checkpoint_path_tmp = f"{checkpoint_path}.tmp"
+ with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
+ torch.save(checkpoint, f)
+ # after torch.save is completed, replace the old checkpoint with the new one
+ if g_pathmgr.exists(checkpoint_path):
+ # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails)
+ g_pathmgr.rm(checkpoint_path)
+ success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
+ assert success
+
+ def load_checkpoint(self):
+ ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
+ if ckpt_path is None:
+ self._init_model_state()
+ else:
+ if self.checkpoint_conf.initialize_after_preemption:
+ self._call_model_initializer()
+ self._load_resuming_checkpoint(ckpt_path)
+
+ def _init_model_state(self):
+ # Checking that parameters that won't be saved are indeed frozen
+ # We do this check here before even saving the model to catch errors
+ # are early as possible and not at the end of the first epoch
+ assert_skipped_parameters_are_frozen(
+ patterns=self.checkpoint_conf.skip_saving_parameters,
+ model=self.model,
+ )
+
+ # Checking that parameters that won't be saved are initialized from
+ # within the model definition, unless `initialize_after_preemption`
+ # is explicitly set to `True`. If not, this is a bug, and after
+ # preemption, the `skip_saving_parameters` will have random values
+ allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
+ with with_check_parameter_frozen(
+ patterns=self.checkpoint_conf.skip_saving_parameters,
+ model=self.model,
+ disabled=allow_init_skip_parameters,
+ ):
+ self._call_model_initializer()
+
+ def _call_model_initializer(self):
+ model_weight_initializer = instantiate(
+ self.checkpoint_conf.model_weight_initializer
+ )
+ if model_weight_initializer is not None:
+ logging.info(
+ f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
+ )
+ self.model = model_weight_initializer(model=self.model)
+
+ def _load_resuming_checkpoint(self, ckpt_path: str):
+ logging.info(f"Resuming training from {ckpt_path}")
+
+ with g_pathmgr.open(ckpt_path, "rb") as f:
+ checkpoint = torch.load(f, map_location="cpu")
+ load_state_dict_into_model(
+ model=self.model,
+ state_dict=checkpoint["model"],
+ ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
+ )
+
+ self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
+ self.loss.load_state_dict(checkpoint["loss"], strict=True)
+ self.epoch = checkpoint["epoch"]
+ self.steps = checkpoint["steps"]
+ self.ckpt_time_elapsed = checkpoint.get("time_elapsed")
+
+ if self.optim_conf.amp.enabled and "scaler" in checkpoint:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+
+ self.best_meter_values = checkpoint.get("best_meter_values", {})
+
+ if "train_dataset" in checkpoint and self.train_dataset is not None:
+ self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])
+
+ def is_intermediate_val_epoch(self, epoch):
+ return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1
+
+ def _step(
+ self,
+ batch: BatchedVideoDatapoint,
+ model: nn.Module,
+ phase: str,
+ ):
+
+ outputs = model(batch)
+ targets = batch.masks
+ batch_size = len(batch.img_batch)
+
+ key = batch.dict_key # key for dataset
+ loss = self.loss[key](outputs, targets)
+ loss_str = f"Losses/{phase}_{key}_loss"
+
+ loss_log_str = os.path.join("Step_Losses", loss_str)
+
+ # loss contains multiple sub-components we wish to log
+ step_losses = {}
+ if isinstance(loss, dict):
+ step_losses.update(
+ {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
+ )
+ loss = self._log_loss_detailed_and_return_core_loss(
+ loss, loss_log_str, self.steps[phase]
+ )
+
+ if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
+ self.logger.log(
+ loss_log_str,
+ loss,
+ self.steps[phase],
+ )
+
+ self.steps[phase] += 1
+
+ ret_tuple = {loss_str: loss}, batch_size, step_losses
+
+ if phase in self.meters and key in self.meters[phase]:
+ meters_dict = self.meters[phase][key]
+ if meters_dict is not None:
+ for _, meter in meters_dict.items():
+ meter.update(
+ find_stages=outputs,
+ find_metadatas=batch.metadata,
+ )
+
+ return ret_tuple
+
+ def run(self):
+ assert self.mode in ["train", "train_only", "val"]
+ if self.mode == "train":
+ if self.epoch > 0:
+ logging.info(f"Resuming training from epoch: {self.epoch}")
+ # resuming from a checkpoint
+ if self.is_intermediate_val_epoch(self.epoch - 1):
+ logging.info("Running previous val epoch")
+ self.epoch -= 1
+ self.run_val()
+ self.epoch += 1
+ self.run_train()
+ self.run_val()
+ elif self.mode == "val":
+ self.run_val()
+ elif self.mode == "train_only":
+ self.run_train()
+
+ def _setup_dataloaders(self):
+ self.train_dataset = None
+ self.val_dataset = None
+
+ if self.mode in ["train", "val"]:
+ self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))
+
+ if self.mode in ["train", "train_only"]:
+ self.train_dataset = instantiate(self.data_conf.train)
+
+ def run_train(self):
+
+ while self.epoch < self.max_epochs:
+ dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
+ barrier()
+ outs = self.train_epoch(dataloader)
+ self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
+
+ # log train to text file.
+ if self.distributed_rank == 0:
+ with g_pathmgr.open(
+ os.path.join(self.logging_conf.log_dir, "train_stats.json"),
+ "a",
+ ) as f:
+ f.write(json.dumps(outs) + "\n")
+
+ # Save checkpoint before validating
+ self.save_checkpoint(self.epoch + 1)
+
+ del dataloader
+ gc.collect()
+
+ # Run val, not running on last epoch since will run after the
+ # loop anyway
+ if self.is_intermediate_val_epoch(self.epoch):
+ self.run_val()
+
+ if self.distributed_rank == 0:
+ self.best_meter_values.update(self._get_trainer_state("train"))
+ with g_pathmgr.open(
+ os.path.join(self.logging_conf.log_dir, "best_stats.json"),
+ "a",
+ ) as f:
+ f.write(json.dumps(self.best_meter_values) + "\n")
+
+ self.epoch += 1
+ # epoch was incremented in the loop but the val step runs out of the loop
+ self.epoch -= 1
+
+ def run_val(self):
+ if not self.val_dataset:
+ return
+
+ dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
+ outs = self.val_epoch(dataloader, phase=Phase.VAL)
+ del dataloader
+ gc.collect()
+ self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
+
+ if self.distributed_rank == 0:
+ with g_pathmgr.open(
+ os.path.join(self.logging_conf.log_dir, "val_stats.json"),
+ "a",
+ ) as f:
+ f.write(json.dumps(outs) + "\n")
+
+ def val_epoch(self, val_loader, phase):
+ batch_time = AverageMeter("Batch Time", self.device, ":.2f")
+ data_time = AverageMeter("Data Time", self.device, ":.2f")
+ mem = MemMeter("Mem (GB)", self.device, ":.2f")
+
+ iters_per_epoch = len(val_loader)
+
+ curr_phases = [phase]
+ curr_models = [self.model]
+
+ loss_names = []
+ for p in curr_phases:
+ for key in self.loss.keys():
+ loss_names.append(f"Losses/{p}_{key}_loss")
+
+ loss_mts = OrderedDict(
+ [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
+ )
+ extra_loss_mts = {}
+
+ for model in curr_models:
+ model.eval()
+ if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
+ unwrap_ddp_if_wrapped(model).on_validation_epoch_start()
+
+ progress = ProgressMeter(
+ iters_per_epoch,
+ [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
+ self._get_meters(curr_phases),
+ prefix="Val Epoch: [{}]".format(self.epoch),
+ )
+
+ end = time.time()
+
+ for data_iter, batch in enumerate(val_loader):
+
+ # measure data loading time
+ data_time.update(time.time() - end)
+
+ batch = batch.to(self.device, non_blocking=True)
+
+ # compute output
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(
+ enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
+ dtype=(
+ get_amp_type(self.optim_conf.amp.amp_dtype)
+ if self.optim_conf
+ else None
+ ),
+ ):
+ for phase, model in zip(curr_phases, curr_models):
+ loss_dict, batch_size, extra_losses = self._step(
+ batch,
+ model,
+ phase,
+ )
+
+ assert len(loss_dict) == 1
+ loss_key, loss = loss_dict.popitem()
+
+ loss_mts[loss_key].update(loss.item(), batch_size)
+
+ for k, v in extra_losses.items():
+ if k not in extra_loss_mts:
+ extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
+ extra_loss_mts[k].update(v.item(), batch_size)
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ self.time_elapsed_meter.update(
+ time.time() - self.start_time + self.ckpt_time_elapsed
+ )
+
+ if torch.cuda.is_available():
+ mem.update(reset_peak_usage=True)
+
+ if data_iter % self.logging_conf.log_freq == 0:
+ progress.display(data_iter)
+
+ if data_iter % self.logging_conf.log_scalar_frequency == 0:
+ # Log progress meters.
+ for progress_meter in progress.meters:
+ self.logger.log(
+ os.path.join("Step_Stats", phase, progress_meter.name),
+ progress_meter.val,
+ self.steps[Phase.VAL],
+ )
+
+ if data_iter % 10 == 0:
+ dist.barrier()
+
+ self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
+ self._log_timers(phase)
+ for model in curr_models:
+ if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
+ unwrap_ddp_if_wrapped(model).on_validation_epoch_end()
+
+ out_dict = self._log_meters_and_save_best_ckpts(curr_phases)
+
+ for k, v in loss_mts.items():
+ out_dict[k] = v.avg
+ for k, v in extra_loss_mts.items():
+ out_dict[k] = v.avg
+
+ for phase in curr_phases:
+ out_dict.update(self._get_trainer_state(phase))
+ self._reset_meters(curr_phases)
+ logging.info(f"Meters: {out_dict}")
+ return out_dict
+
+ def _get_trainer_state(self, phase):
+ return {
+ "Trainer/where": self.where,
+ "Trainer/epoch": self.epoch,
+ f"Trainer/steps_{phase}": self.steps[phase],
+ }
+
+ def train_epoch(self, train_loader):
+
+ # Init stat meters
+ batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
+ data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
+ mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
+ data_times = []
+ phase = Phase.TRAIN
+
+ iters_per_epoch = len(train_loader)
+
+ loss_names = []
+ for batch_key in self.loss.keys():
+ loss_names.append(f"Losses/{phase}_{batch_key}_loss")
+
+ loss_mts = OrderedDict(
+ [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
+ )
+ extra_loss_mts = {}
+
+ progress = ProgressMeter(
+ iters_per_epoch,
+ [
+ batch_time_meter,
+ data_time_meter,
+ mem_meter,
+ self.time_elapsed_meter,
+ *loss_mts.values(),
+ ],
+ self._get_meters([phase]),
+ prefix="Train Epoch: [{}]".format(self.epoch),
+ )
+
+ # Model training loop
+ self.model.train()
+ end = time.time()
+
+ for data_iter, batch in enumerate(train_loader):
+ # measure data loading time
+ data_time_meter.update(time.time() - end)
+ data_times.append(data_time_meter.val)
+ batch = batch.to(
+ self.device, non_blocking=True
+ ) # move tensors in a tensorclass
+
+ try:
+ self._run_step(batch, phase, loss_mts, extra_loss_mts)
+
+ # compute gradient and do optim step
+ exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
+ self.where = float(exact_epoch) / self.max_epochs
+ assert self.where <= 1 + self.EPSILON
+ if self.where < 1.0:
+ self.optim.step_schedulers(
+ self.where, step=int(exact_epoch * iters_per_epoch)
+ )
+ else:
+ logging.warning(
+ f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]."
+ )
+
+ # Log schedulers
+ if data_iter % self.logging_conf.log_scalar_frequency == 0:
+ for j, param_group in enumerate(self.optim.optimizer.param_groups):
+ for option in self.optim.schedulers[j]:
+ optim_prefix = (
+ "" + f"{j}_"
+ if len(self.optim.optimizer.param_groups) > 1
+ else ""
+ )
+ self.logger.log(
+ os.path.join("Optim", f"{optim_prefix}", option),
+ param_group[option],
+ self.steps[phase],
+ )
+
+ # Clipping gradients and detecting diverging gradients
+ if self.gradient_clipper is not None:
+ self.scaler.unscale_(self.optim.optimizer)
+ self.gradient_clipper(model=self.model)
+
+ if self.gradient_logger is not None:
+ self.gradient_logger(
+ self.model, rank=self.distributed_rank, where=self.where
+ )
+
+ # Optimizer step: the scaler will make sure gradients are not
+ # applied if the gradients are infinite
+ self.scaler.step(self.optim.optimizer)
+ self.scaler.update()
+
+ # measure elapsed time
+ batch_time_meter.update(time.time() - end)
+ end = time.time()
+
+ self.time_elapsed_meter.update(
+ time.time() - self.start_time + self.ckpt_time_elapsed
+ )
+
+ mem_meter.update(reset_peak_usage=True)
+ if data_iter % self.logging_conf.log_freq == 0:
+ progress.display(data_iter)
+
+ if data_iter % self.logging_conf.log_scalar_frequency == 0:
+ # Log progress meters.
+ for progress_meter in progress.meters:
+ self.logger.log(
+ os.path.join("Step_Stats", phase, progress_meter.name),
+ progress_meter.val,
+ self.steps[phase],
+ )
+
+ # Catching NaN/Inf errors in the loss
+ except FloatingPointError as e:
+ raise e
+
+ self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
+ self._log_timers(Phase.TRAIN)
+ self._log_sync_data_times(Phase.TRAIN, data_times)
+
+ out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])
+
+ for k, v in loss_mts.items():
+ out_dict[k] = v.avg
+ for k, v in extra_loss_mts.items():
+ out_dict[k] = v.avg
+ out_dict.update(self._get_trainer_state(phase))
+ logging.info(f"Losses and meters: {out_dict}")
+ self._reset_meters([phase])
+ return out_dict
+
+ def _log_sync_data_times(self, phase, data_times):
+ data_times = all_reduce_max(torch.tensor(data_times)).tolist()
+ steps = range(self.steps[phase] - len(data_times), self.steps[phase])
+ for step, data_time in zip(steps, data_times):
+ if step % self.logging_conf.log_scalar_frequency == 0:
+ self.logger.log(
+ os.path.join("Step_Stats", phase, "Data Time Synced"),
+ data_time,
+ step,
+ )
+
+ def _run_step(
+ self,
+ batch: BatchedVideoDatapoint,
+ phase: str,
+ loss_mts: Dict[str, AverageMeter],
+ extra_loss_mts: Dict[str, AverageMeter],
+ raise_on_error: bool = True,
+ ):
+ """
+ Run the forward / backward
+ """
+
+ # it's important to set grads to None, especially with Adam since 0
+ # grads will also update a model even if the step doesn't produce
+ # gradients
+ self.optim.zero_grad(set_to_none=True)
+ with torch.cuda.amp.autocast(
+ enabled=self.optim_conf.amp.enabled,
+ dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
+ ):
+ loss_dict, batch_size, extra_losses = self._step(
+ batch,
+ self.model,
+ phase,
+ )
+
+ assert len(loss_dict) == 1
+ loss_key, loss = loss_dict.popitem()
+
+ if not math.isfinite(loss.item()):
+ error_msg = f"Loss is {loss.item()}, attempting to stop training"
+ logging.error(error_msg)
+ if raise_on_error:
+ raise FloatingPointError(error_msg)
+ else:
+ return
+
+ self.scaler.scale(loss).backward()
+ loss_mts[loss_key].update(loss.item(), batch_size)
+ for extra_loss_key, extra_loss in extra_losses.items():
+ if extra_loss_key not in extra_loss_mts:
+ extra_loss_mts[extra_loss_key] = AverageMeter(
+ extra_loss_key, self.device, ":.2e"
+ )
+ extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)
+
+ def _log_meters_and_save_best_ckpts(self, phases: List[str]):
+ logging.info("Synchronizing meters")
+ out_dict = {}
+ checkpoint_save_keys = []
+ for key, meter in self._get_meters(phases).items():
+ meter_output = meter.compute_synced()
+ is_better_check = getattr(meter, "is_better", None)
+
+ for meter_subkey, meter_value in meter_output.items():
+ out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value
+
+ if is_better_check is None:
+ continue
+
+ tracked_meter_key = os.path.join(key, meter_subkey)
+ if tracked_meter_key not in self.best_meter_values or is_better_check(
+ meter_value,
+ self.best_meter_values[tracked_meter_key],
+ ):
+ self.best_meter_values[tracked_meter_key] = meter_value
+
+ if (
+ self.checkpoint_conf.save_best_meters is not None
+ and key in self.checkpoint_conf.save_best_meters
+ ):
+ checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))
+
+ if len(checkpoint_save_keys) > 0:
+ self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)
+
+ return out_dict
+
+ def _log_timers(self, phase):
+ time_remaining = 0
+ epochs_remaining = self.max_epochs - self.epoch - 1
+ val_epochs_remaining = sum(
+ n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
+ )
+
+ # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with
+ # the end epoch.
+ if (self.max_epochs - 1) % self.val_epoch_freq != 0:
+ val_epochs_remaining += 1
+
+ # Remove the current val run from estimate
+ if phase == Phase.VAL:
+ val_epochs_remaining -= 1
+
+ time_remaining += (
+ epochs_remaining * self.est_epoch_time[Phase.TRAIN]
+ + val_epochs_remaining * self.est_epoch_time[Phase.VAL]
+ )
+
+ self.logger.log(
+ os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
+ self.time_elapsed_meter.val,
+ self.steps[phase],
+ )
+
+ logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")
+
+ def _reset_meters(self, phases: str) -> None:
+ for meter in self._get_meters(phases).values():
+ meter.reset()
+
+ def _check_val_key_match(self, val_keys, phase):
+ if val_keys is not None:
+ # Check if there are any duplicates
+ assert len(val_keys) == len(
+ set(val_keys)
+ ), f"Duplicate keys in val datasets, keys: {val_keys}"
+
+ # Check that the keys match the meter keys
+ if self.meters_conf is not None and phase in self.meters_conf:
+ assert set(val_keys) == set(self.meters_conf[phase].keys()), (
+ f"Keys in val datasets do not match the keys in meters."
+ f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
+ f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
+ )
+
+ if self.loss_conf is not None:
+ loss_keys = set(self.loss_conf.keys()) - set(["all"])
+ assert all([k in loss_keys for k in val_keys]), (
+ f"Keys in val datasets do not match the keys in losses."
+ f"\nMissing in losses: {set(val_keys) - loss_keys}"
+ f"\nMissing in val datasets: {loss_keys - set(val_keys)}"
+ )
+
+ def _setup_components(self):
+
+ # Get the keys for all the val datasets, if any
+ val_phase = Phase.VAL
+ val_keys = None
+ if self.data_conf.get(val_phase, None) is not None:
+ val_keys = collect_dict_keys(self.data_conf[val_phase])
+ # Additional checks on the sanity of the config for val datasets
+ self._check_val_key_match(val_keys, phase=val_phase)
+
+ logging.info("Setting up components: Model, loss, optim, meters etc.")
+ self.epoch = 0
+ self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}
+
+ self.logger = Logger(self.logging_conf)
+
+ self.model = instantiate(self.model_conf, _convert_="all")
+ print_model_summary(self.model)
+
+ self.loss = None
+ if self.loss_conf:
+ self.loss = {
+ key: el # wrap_base_loss(el)
+ for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
+ }
+ self.loss = nn.ModuleDict(self.loss)
+
+ self.meters = {}
+ self.best_meter_values = {}
+ if self.meters_conf:
+ self.meters = instantiate(self.meters_conf, _convert_="all")
+
+ self.scaler = torch.amp.GradScaler(
+ self.device,
+ enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
+ )
+
+ self.gradient_clipper = (
+ instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
+ )
+ self.gradient_logger = (
+ instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
+ )
+
+ logging.info("Finished setting up components: Model, loss, optim, meters etc.")
+
+ def _construct_optimizers(self):
+ self.optim = construct_optimizer(
+ self.model,
+ self.optim_conf.optimizer,
+ self.optim_conf.options,
+ self.optim_conf.param_group_modifiers,
+ )
+
+ def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
+ core_loss = loss.pop(CORE_LOSS_KEY)
+ if step % self.logging_conf.log_scalar_frequency == 0:
+ for k in loss:
+ log_str = os.path.join(loss_str, k)
+ self.logger.log(log_str, loss[k], step)
+ return core_loss
+
+
+def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
+ """
+ Prints the model and the number of parameters in the model.
+ # Multiple packages provide this info in a nice table format
+ # However, they need us to provide an `input` (as they also write down the output sizes)
+ # Our models are complex, and a single input is restrictive.
+ # https://github.com/sksq96/pytorch-summary
+ # https://github.com/nmhkahn/torchsummaryX
+ """
+ if get_rank() != 0:
+ return
+ param_kwargs = {}
+ trainable_parameters = sum(
+ p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
+ )
+ total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
+ non_trainable_parameters = total_parameters - trainable_parameters
+ logging.info("==" * 10)
+ logging.info(f"Summary for model {type(model)}")
+ logging.info(f"Model is {model}")
+ logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
+ logging.info(
+ f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
+ )
+ logging.info(
+ f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
+ )
+ logging.info("==" * 10)
+
+ if log_dir:
+ output_fpath = os.path.join(log_dir, "model.txt")
+ with g_pathmgr.open(output_fpath, "w") as f:
+ print(model, file=f)
+
+
+PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
+
+
+def get_human_readable_count(number: int) -> str:
+ """
+ Abbreviates an integer number with K, M, B, T for thousands, millions,
+ billions and trillions, respectively.
+ Examples:
+ >>> get_human_readable_count(123)
+ '123 '
+ >>> get_human_readable_count(1234) # (one thousand)
+ '1.2 K'
+ >>> get_human_readable_count(2e6) # (two million)
+ '2.0 M'
+ >>> get_human_readable_count(3e9) # (three billion)
+ '3.0 B'
+ >>> get_human_readable_count(4e14) # (four hundred trillion)
+ '400 T'
+ >>> get_human_readable_count(5e15) # (more than trillion)
+ '5,000 T'
+ Args:
+ number: a positive integer number
+ Return:
+ A string formatted according to the pattern described above.
+ """
+ assert number >= 0
+ labels = PARAMETER_NUM_UNITS
+ num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
+ num_groups = int(np.ceil(num_digits / 3))
+ num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
+ shift = -3 * (num_groups - 1)
+ number = number * (10**shift)
+ index = num_groups - 1
+ if index < 1 or number >= 100:
+ return f"{int(number):,d} {labels[index]}"
+ else:
+ return f"{number:,.1f} {labels[index]}"
diff --git a/phantom/submodules/sam2/training/utils/__init__.py b/phantom/submodules/sam2/training/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/phantom/submodules/sam2/training/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/phantom/submodules/sam2/training/utils/checkpoint_utils.py b/phantom/submodules/sam2/training/utils/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f76689f341dedc485c0c32d096fb5b2e8337bea9
--- /dev/null
+++ b/phantom/submodules/sam2/training/utils/checkpoint_utils.py
@@ -0,0 +1,361 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import fnmatch
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
+
+import numpy as np
+import torch
+import torch.nn as nn
+from iopath.common.file_io import g_pathmgr
+from torch.jit._script import RecursiveScriptModule
+
+
+def unix_pattern_to_parameter_names(
+ constraints: List[str], all_parameter_names: Sequence[str]
+) -> Union[None, Set[str]]:
+ """
+ Go through the list of parameter names and select those that match
+ any of the provided constraints
+ """
+ parameter_names = []
+ for param_name in constraints:
+ matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
+ assert (
+ len(matching_parameters) > 0
+ ), f"param_names {param_name} don't match any param in the given names."
+ parameter_names.append(matching_parameters)
+ return set.union(*parameter_names)
+
+
+def filter_params_matching_unix_pattern(
+ patterns: List[str], state_dict: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+ """
+ Remove from the state dictionary the parameters matching the provided unix patterns
+
+ Args:
+ patterns: the list of unix patterns to exclude
+ state_dict: the dictionary to filter
+
+ Returns:
+ A new state dictionary
+ """
+ if len(patterns) == 0:
+ return {}
+
+ all_keys = list(state_dict.keys())
+ included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
+ return {k: state_dict[k] for k in included_keys}
+
+
+def exclude_params_matching_unix_pattern(
+ patterns: List[str], state_dict: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+ """
+ Remove from the state dictionary the parameters matching the provided unix patterns
+
+ Args:
+ patterns: the list of unix patterns to exclude
+ state_dict: the dictionary to filter
+
+ Returns:
+ A new state dictionary
+ """
+ if len(patterns) == 0:
+ return state_dict
+
+ all_keys = list(state_dict.keys())
+ excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys)
+ return {k: v for k, v in state_dict.items() if k not in excluded_keys}
+
+
+def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
+ keys = []
+ trace = []
+ for k, v in state_dict.items():
+ keys.append(k)
+ trace.append(v.sum().item())
+ trace = np.array(trace)[np.argsort(keys)]
+ return trace
+
+
+def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
+ """
+ Verifies that all the parameters matching the provided patterns
+ are frozen - this acts as a safeguard when ignoring parameter
+ when saving checkpoints - if the parameters are in fact trainable
+ """
+ if not patterns:
+ return
+
+ frozen_state_dict = filter_params_matching_unix_pattern(
+ patterns=patterns, state_dict=model.state_dict()
+ )
+ non_frozen_keys = {
+ n
+ for n, p in model.named_parameters()
+ if n in frozen_state_dict and p.requires_grad
+ }
+ if non_frozen_keys:
+ raise ValueError(
+ f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
+ )
+
+
+@contextlib.contextmanager
+def with_check_parameter_frozen(
+ model: nn.Module, patterns: List[str], disabled: bool = True
+):
+ """
+ Context manager that inspects a model surrounding a piece of code
+ and verifies if the model has been updated by this piece of code
+
+ The function will raise an exception if the model has been updated
+ on at least one of the parameter that matches one of the pattern
+
+ Args:
+ model: the model that might have been updated
+ patterns: for the parameters we want to observe
+ allowed:
+ """
+ if not patterns or disabled:
+ yield
+ return
+
+ frozen_state_dict = filter_params_matching_unix_pattern(
+ patterns=patterns, state_dict=model.state_dict()
+ )
+ summary_before = _get_state_dict_summary(frozen_state_dict)
+
+ yield
+
+ frozen_state_dict = filter_params_matching_unix_pattern(
+ patterns=patterns, state_dict=model.state_dict()
+ )
+ summary_after = _get_state_dict_summary(frozen_state_dict)
+
+ if not np.allclose(summary_before, summary_after, atol=1e-6):
+ raise ValueError(
+ f"""
+ The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
+ You can resolve this error by either initializing those parameters from within the model definition
+ or using the flag `trainer.checkpoint.initialize_after_preemption` to True.
+ """
+ )
+
+
+class CkptExcludeKernel:
+ """
+ Removes the keys from the given model state_dict that match the key_pattern.
+
+ Args:
+ key_pattern: Patterns used to select the keys in the state_dict
+ that are eligible for this kernel.
+ """
+
+ def __init__(self, key_pattern: List[str]):
+ self.key_pattern = key_pattern
+
+ def __call__(self, state_dict: Dict):
+ """
+ Args:
+ state_dict: A dictionary representing the given checkpoint's state dict.
+ """
+ if len(self.key_pattern) == 0:
+ return state_dict
+ exclude_keys = unix_pattern_to_parameter_names(
+ self.key_pattern, state_dict.keys()
+ )
+ return {k: v for k, v in state_dict.items() if k not in exclude_keys}
+
+
+def load_checkpoint(
+ path_list: List[str],
+ pick_recursive_keys: Optional[List[str]] = None,
+ map_location: str = "cpu",
+) -> Any:
+ """
+ Loads a checkpoint from the specified path.
+
+ Args:
+ path_list: A list of paths which contain the checkpoint. Each element
+ is tried (in order) until a file that exists is found. That file is then
+ used to read the checkpoint.
+ pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
+ For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
+ map_location (str): a function, torch.device, string or a dict specifying how to
+ remap storage locations
+
+ Returns: Model with the matchin pre-trained weights loaded.
+ """
+ path_exists = False
+ for path in path_list:
+ if g_pathmgr.exists(path):
+ path_exists = True
+ break
+
+ if not path_exists:
+ raise ValueError(f"No path exists in {path_list}")
+
+ with g_pathmgr.open(path, "rb") as f:
+ checkpoint = torch.load(f, map_location=map_location)
+
+ logging.info(f"Loaded checkpoint from {path}")
+ if pick_recursive_keys is not None:
+ for key in pick_recursive_keys:
+ checkpoint = checkpoint[key]
+ return checkpoint
+
+
+def get_state_dict(checkpoint, ckpt_state_dict_keys):
+ if isinstance(checkpoint, RecursiveScriptModule):
+ # This is a torchscript JIT model
+ return checkpoint.state_dict()
+ pre_train_dict = checkpoint
+ for i, key in enumerate(ckpt_state_dict_keys):
+ if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
+ isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
+ ):
+ key_str = (
+ '["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]'
+ )
+ raise KeyError(
+ f"'{key}' not found in checkpoint{key_str} "
+ f"with keys: {pre_train_dict.keys()}"
+ )
+ pre_train_dict = pre_train_dict[key]
+ return pre_train_dict
+
+
+def load_checkpoint_and_apply_kernels(
+ checkpoint_path: str,
+ checkpoint_kernels: List[Callable] = None,
+ ckpt_state_dict_keys: Tuple[str] = ("state_dict",),
+ map_location: str = "cpu",
+) -> nn.Module:
+ """
+ Performs checkpoint loading with a variety of pre-processing kernel applied in
+ sequence.
+
+ Args:
+ checkpoint_path (str): Path to the checkpoint.
+ checkpoint_kernels List(Callable): A list of checkpoint processing kernels
+ to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
+ `CkptExcludeKernel`, etc. These kernels are applied in the
+ given order.
+ ckpt_state_dict_keys (str): Keys containing the model state dict.
+ map_location (str): a function, torch.device, string or a dict specifying how to
+ remap storage locations
+
+ Returns: Model with the matchin pre-trained weights loaded.
+ """
+ assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
+ checkpoint_path
+ )
+
+ # Load the checkpoint on CPU to avoid GPU mem spike.
+ with g_pathmgr.open(checkpoint_path, "rb") as f:
+ checkpoint = torch.load(f, map_location=map_location)
+
+ pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)
+
+ # Not logging into info etc since it's a huge log
+ logging.debug(
+ "Loaded Checkpoint State Dict pre-kernel application: %s"
+ % str(", ".join(list(pre_train_dict.keys())))
+ )
+ # Apply kernels
+ if checkpoint_kernels is not None:
+ for f in checkpoint_kernels:
+ pre_train_dict = f(state_dict=pre_train_dict)
+
+ logging.debug(
+ "Loaded Checkpoint State Dict Post-kernel application %s"
+ % str(", ".join(list(pre_train_dict.keys())))
+ )
+
+ return pre_train_dict
+
+
+def check_load_state_dict_errors(
+ missing_keys,
+ unexpected_keys,
+ strict: bool,
+ ignore_missing_keys: List[str] = None,
+ ignore_unexpected_keys: List[str] = None,
+):
+ if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
+ ignored_keys = unix_pattern_to_parameter_names(
+ ignore_missing_keys, missing_keys
+ )
+ missing_keys = [key for key in missing_keys if key not in ignored_keys]
+
+ if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
+ ignored_unexpected_keys = unix_pattern_to_parameter_names(
+ ignore_unexpected_keys, unexpected_keys
+ )
+ unexpected_keys = [
+ key for key in unexpected_keys if key not in ignored_unexpected_keys
+ ]
+
+ err = "State key mismatch."
+ if unexpected_keys:
+ err += f" Unexpected keys: {unexpected_keys}."
+ if missing_keys:
+ err += f" Missing keys: {missing_keys}."
+
+ if unexpected_keys or missing_keys:
+ logging.warning(err)
+ if unexpected_keys or strict:
+ raise KeyError(err)
+
+
+def load_state_dict_into_model(
+ state_dict: Dict,
+ model: nn.Module,
+ strict: bool = True,
+ ignore_missing_keys: List[str] = None,
+ ignore_unexpected_keys: List[str] = None,
+ checkpoint_kernels: List[Callable] = None,
+):
+ """
+ Loads a state dict into the given model.
+
+ Args:
+ state_dict: A dictionary containing the model's
+ state dict, or a subset if strict is False
+ model: Model to load the checkpoint weights into
+ strict: raise if the state_dict has missing state keys
+ ignore_missing_keys: unix pattern of keys to ignore
+ """
+ # Apply kernels
+ if checkpoint_kernels is not None:
+ for f in checkpoint_kernels:
+ state_dict = f(state_dict=state_dict)
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+
+ check_load_state_dict_errors(
+ missing_keys,
+ unexpected_keys,
+ strict=strict,
+ ignore_missing_keys=ignore_missing_keys,
+ ignore_unexpected_keys=ignore_unexpected_keys,
+ )
+ return model
diff --git a/phantom/submodules/sam2/training/utils/data_utils.py b/phantom/submodules/sam2/training/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbd0115355c97a27c601a833985466e558063b91
--- /dev/null
+++ b/phantom/submodules/sam2/training/utils/data_utils.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from PIL import Image as PILImage
+from tensordict import tensorclass
+
+
+@tensorclass
+class BatchedVideoMetaData:
+ """
+ This class represents metadata about a batch of videos.
+ Attributes:
+ unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
+ frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
+ """
+
+ unique_objects_identifier: torch.LongTensor
+ frame_orig_size: torch.LongTensor
+
+
+@tensorclass
+class BatchedVideoDatapoint:
+ """
+ This class represents a batch of videos with associated annotations and metadata.
+ Attributes:
+ img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
+ obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch.
+ masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
+ metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
+ dict_key: A string key used to identify the batch.
+ """
+
+ img_batch: torch.FloatTensor
+ obj_to_frame_idx: torch.IntTensor
+ masks: torch.BoolTensor
+ metadata: BatchedVideoMetaData
+
+ dict_key: str
+
+ def pin_memory(self, device=None):
+ return self.apply(torch.Tensor.pin_memory, device=device)
+
+ @property
+ def num_frames(self) -> int:
+ """
+ Returns the number of frames per video.
+ """
+ return self.batch_size[0]
+
+ @property
+ def num_videos(self) -> int:
+ """
+ Returns the number of videos in the batch.
+ """
+ return self.img_batch.shape[1]
+
+ @property
+ def flat_obj_to_img_idx(self) -> torch.IntTensor:
+ """
+ Returns a flattened tensor containing the object to img index.
+ The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW]
+ """
+ frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
+ flat_idx = video_idx * self.num_frames + frame_idx
+ return flat_idx
+
+ @property
+ def flat_img_batch(self) -> torch.FloatTensor:
+ """
+ Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
+ """
+
+ return self.img_batch.transpose(0, 1).flatten(0, 1)
+
+
+@dataclass
+class Object:
+ # Id of the object in the media
+ object_id: int
+ # Index of the frame in the media (0 if single image)
+ frame_index: int
+ segment: Union[torch.Tensor, dict] # RLE dict or binary mask
+
+
+@dataclass
+class Frame:
+ data: Union[torch.Tensor, PILImage.Image]
+ objects: List[Object]
+
+
+@dataclass
+class VideoDatapoint:
+ """Refers to an image/video and all its annotations"""
+
+ frames: List[Frame]
+ video_id: int
+ size: Tuple[int, int]
+
+
+def collate_fn(
+ batch: List[VideoDatapoint],
+ dict_key,
+) -> BatchedVideoDatapoint:
+ """
+ Args:
+ batch: A list of VideoDatapoint instances.
+ dict_key (str): A string key used to identify the batch.
+ """
+ img_batch = []
+ for video in batch:
+ img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]
+
+ img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
+ T = img_batch.shape[0]
+ # Prepare data structures for sequential processing. Per-frame processing but batched across videos.
+ step_t_objects_identifier = [[] for _ in range(T)]
+ step_t_frame_orig_size = [[] for _ in range(T)]
+
+ step_t_masks = [[] for _ in range(T)]
+ step_t_obj_to_frame_idx = [
+ [] for _ in range(T)
+ ] # List to store frame indices for each time step
+
+ for video_idx, video in enumerate(batch):
+ orig_video_id = video.video_id
+ orig_frame_size = video.size
+ for t, frame in enumerate(video.frames):
+ objects = frame.objects
+ for obj in objects:
+ orig_obj_id = obj.object_id
+ orig_frame_idx = obj.frame_index
+ step_t_obj_to_frame_idx[t].append(
+ torch.tensor([t, video_idx], dtype=torch.int)
+ )
+ step_t_masks[t].append(obj.segment.to(torch.bool))
+ step_t_objects_identifier[t].append(
+ torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
+ )
+ step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))
+
+ obj_to_frame_idx = torch.stack(
+ [
+ torch.stack(obj_to_frame_idx, dim=0)
+ for obj_to_frame_idx in step_t_obj_to_frame_idx
+ ],
+ dim=0,
+ )
+ masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
+ objects_identifier = torch.stack(
+ [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0
+ )
+ frame_orig_size = torch.stack(
+ [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0
+ )
+ return BatchedVideoDatapoint(
+ img_batch=img_batch,
+ obj_to_frame_idx=obj_to_frame_idx,
+ masks=masks,
+ metadata=BatchedVideoMetaData(
+ unique_objects_identifier=objects_identifier,
+ frame_orig_size=frame_orig_size,
+ ),
+ dict_key=dict_key,
+ batch_size=[T],
+ )
diff --git a/phantom/submodules/sam2/training/utils/distributed.py b/phantom/submodules/sam2/training/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..f614b40427f40350c4df9e695cd327cb4d6a96f6
--- /dev/null
+++ b/phantom/submodules/sam2/training/utils/distributed.py
@@ -0,0 +1,576 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import datetime
+import functools
+import io
+import logging
+import os
+import random
+import tempfile
+import time
+from typing import Any, Callable, List, Tuple
+
+import torch
+import torch.autograd as autograd
+import torch.distributed as dist
+
+
+# Default to GPU 0
+_cuda_device_index: int = 0
+
+# Setting _cuda_device_index to -1 internally implies that we should use CPU
+_CPU_DEVICE_INDEX = -1
+_PRIMARY_RANK = 0
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+
+ if dist.get_backend() == "nccl":
+ # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
+ # being much slower than others causing a timeout (which can happen in relation
+ # or LVIS class mAP evaluation).
+ timeout = 43200
+ return dist.new_group(
+ backend="gloo",
+ timeout=datetime.timedelta(seconds=timeout),
+ )
+
+ return dist.group.WORLD
+
+
+def is_main_process():
+ """Return true if the current process is the main one"""
+ return get_rank() == 0
+
+
+def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
+ `all_gather` above, but using filesystem instead of collective ops.
+
+ If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
+ (and other ranks will have an empty list).
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ print("gathering via files")
+ cpu_group = _get_global_gloo_group()
+
+ # if unspecified, we will save to the current python file dir
+ if filesys_save_dir is not None:
+ save_dir = filesys_save_dir
+ elif "EXP_DIR" in os.environ:
+ save_dir = os.environ["EXP_DIR"]
+ else:
+ # try the same directory where the code is stored
+ save_dir = filesys_save_dir or os.path.dirname(__file__)
+ save_dir = os.path.join(save_dir, "all_gather_via_filesys")
+ if is_main_process():
+ os.makedirs(save_dir, exist_ok=True)
+
+ # use a timestamp and salt to distinguish different all_gather
+ timestamp = int(time.time()) if is_main_process() else 0
+ salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
+ # broadcast the timestamp and salt across ranks
+ # (all-reduce will do the broadcasting since only rank 0 is non-zero)
+ timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
+ dist.all_reduce(timestamp_and_salt, group=cpu_group)
+ timestamp, salt = timestamp_and_salt.tolist()
+
+ # save the data to a file on the disk
+ rank_save = get_rank()
+ save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
+ save_data_path = os.path.join(save_dir, save_data_filename)
+ assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
+ torch.save(data, save_data_path)
+ dist.barrier(group=cpu_group)
+
+ # read the data from the files
+ data_list = []
+ if rank_save == 0 or not gather_to_rank_0_only:
+ for rank_load in range(world_size):
+ load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
+ load_data_path = os.path.join(save_dir, load_data_filename)
+ assert os.path.exists(load_data_path), f"cannot read {save_data_path}"
+ data_list.append(torch.load(load_data_path))
+ dist.barrier(group=cpu_group)
+
+ # delete the saved file
+ os.remove(save_data_path)
+ return data_list
+
+
+def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
+ return all_gather_via_filesys(
+ data, filesys_save_dir, gather_to_rank_0_only=True
+ )
+
+ if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
+ return all_gather_via_filesys(data, filesys_save_dir)
+
+ cpu_group = None
+ if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
+ cpu_group = _get_global_gloo_group()
+
+ buffer = io.BytesIO()
+ torch.save(data, buffer)
+ data_view = buffer.getbuffer()
+ device = "cuda" if cpu_group is None else "cpu"
+ tensor = torch.ByteTensor(data_view).to(device)
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
+ size_list = [
+ torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
+ ]
+ if cpu_group is None:
+ dist.all_gather(size_list, local_size)
+ else:
+ print("gathering on cpu")
+ dist.all_gather(size_list, local_size, group=cpu_group)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+ assert isinstance(local_size.item(), int)
+ local_size = int(local_size.item())
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
+ if local_size != max_size:
+ padding = torch.empty(
+ size=(max_size - local_size,), dtype=torch.uint8, device=device
+ )
+ tensor = torch.cat((tensor, padding), dim=0)
+ if cpu_group is None:
+ dist.all_gather(tensor_list, tensor)
+ else:
+ dist.all_gather(tensor_list, tensor, group=cpu_group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
+ buffer = io.BytesIO(tensor.cpu().numpy())
+ obj = torch.load(buffer)
+ data_list.append(obj)
+
+ return data_list
+
+
+def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
+ """
+ For some backends, such as NCCL, communication only works if the
+ tensor is on the GPU. This helper function converts to the correct
+ device and returns the tensor + original device.
+ """
+ orig_device = "cpu" if not tensor.is_cuda else "gpu"
+ if (
+ torch.distributed.is_available()
+ and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
+ and not tensor.is_cuda
+ ):
+ tensor = tensor.cuda()
+ return (tensor, orig_device)
+
+
+def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
+ """
+ For some backends, such as NCCL, communication only works if the
+ tensor is on the GPU. This converts the tensor back to original device.
+ """
+ if tensor.is_cuda and orig_device == "cpu":
+ tensor = tensor.cpu()
+ return tensor
+
+
+def is_distributed_training_run() -> bool:
+ return (
+ torch.distributed.is_available()
+ and torch.distributed.is_initialized()
+ and (torch.distributed.get_world_size() > 1)
+ )
+
+
+def is_primary() -> bool:
+ """
+ Returns True if this is rank 0 of a distributed training job OR if it is
+ a single trainer job. Otherwise False.
+ """
+ return get_rank() == _PRIMARY_RANK
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Wrapper over torch.distributed.all_reduce for performing mean reduction
+ of tensor over all processes.
+ """
+ return all_reduce_op(
+ tensor,
+ torch.distributed.ReduceOp.SUM,
+ lambda t: t / torch.distributed.get_world_size(),
+ )
+
+
+def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Wrapper over torch.distributed.all_reduce for performing sum
+ reduction of tensor over all processes in both distributed /
+ non-distributed scenarios.
+ """
+ return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
+
+
+def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Wrapper over torch.distributed.all_reduce for performing min
+ reduction of tensor over all processes in both distributed /
+ non-distributed scenarios.
+ """
+ return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
+
+
+def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Wrapper over torch.distributed.all_reduce for performing min
+ reduction of tensor over all processes in both distributed /
+ non-distributed scenarios.
+ """
+ return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
+
+
+def all_reduce_op(
+ tensor: torch.Tensor,
+ op: torch.distributed.ReduceOp,
+ after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Wrapper over torch.distributed.all_reduce for performing
+ reduction of tensor over all processes in both distributed /
+ non-distributed scenarios.
+ """
+ if is_distributed_training_run():
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
+ torch.distributed.all_reduce(tensor, op)
+ if after_op_func is not None:
+ tensor = after_op_func(tensor)
+ tensor = convert_to_normal_tensor(tensor, orig_device)
+ return tensor
+
+
+def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
+ """
+ Wrapper over torch.distributed.all_gather for performing
+ 'gather' of 'tensor' over all processes in both distributed /
+ non-distributed scenarios.
+ """
+ if tensor.ndim == 0:
+ # 0 dim tensors cannot be gathered. so unsqueeze
+ tensor = tensor.unsqueeze(0)
+
+ if is_distributed_training_run():
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
+ gathered_tensors = [
+ torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(gathered_tensors, tensor)
+ gathered_tensors = [
+ convert_to_normal_tensor(_tensor, orig_device)
+ for _tensor in gathered_tensors
+ ]
+ else:
+ gathered_tensors = [tensor]
+
+ return gathered_tensors
+
+
+def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
+ gathered_tensors = gather_tensors_from_all(tensor)
+ gathered_tensor = torch.cat(gathered_tensors, 0)
+ return gathered_tensor
+
+
+def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
+ """
+ Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
+ to all processes in both distributed / non-distributed scenarios.
+ """
+ if is_distributed_training_run():
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
+ torch.distributed.broadcast(tensor, src)
+ tensor = convert_to_normal_tensor(tensor, orig_device)
+ return tensor
+
+
+def barrier() -> None:
+ """
+ Wrapper over torch.distributed.barrier, returns without waiting
+ if the distributed process group is not initialized instead of throwing error.
+ """
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ return
+ torch.distributed.barrier()
+
+
+def get_world_size() -> int:
+ """
+ Simple wrapper for correctly getting worldsize in both distributed
+ / non-distributed settings
+ """
+ return (
+ torch.distributed.get_world_size()
+ if torch.distributed.is_available() and torch.distributed.is_initialized()
+ else 1
+ )
+
+
+def get_rank() -> int:
+ """
+ Simple wrapper for correctly getting rank in both distributed
+ / non-distributed settings
+ """
+ return (
+ torch.distributed.get_rank()
+ if torch.distributed.is_available() and torch.distributed.is_initialized()
+ else 0
+ )
+
+
+def get_primary_rank() -> int:
+ return _PRIMARY_RANK
+
+
+def set_cuda_device_index(idx: int) -> None:
+ global _cuda_device_index
+ _cuda_device_index = idx
+ torch.cuda.set_device(_cuda_device_index)
+
+
+def set_cpu_device() -> None:
+ global _cuda_device_index
+ _cuda_device_index = _CPU_DEVICE_INDEX
+
+
+def get_cuda_device_index() -> int:
+ return _cuda_device_index
+
+
+def init_distributed_data_parallel_model(
+ model: torch.nn.Module,
+ broadcast_buffers: bool = False,
+ find_unused_parameters: bool = True,
+ bucket_cap_mb: int = 25,
+) -> torch.nn.parallel.DistributedDataParallel:
+ global _cuda_device_index
+
+ if _cuda_device_index == _CPU_DEVICE_INDEX:
+ # CPU-only model, don't specify device
+ return torch.nn.parallel.DistributedDataParallel(
+ model,
+ broadcast_buffers=broadcast_buffers,
+ find_unused_parameters=find_unused_parameters,
+ bucket_cap_mb=bucket_cap_mb,
+ )
+ else:
+ # GPU model
+ return torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[_cuda_device_index],
+ output_device=_cuda_device_index,
+ broadcast_buffers=broadcast_buffers,
+ find_unused_parameters=find_unused_parameters,
+ bucket_cap_mb=bucket_cap_mb,
+ )
+
+
+def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
+ """Broadcast an object from a source to all workers.
+
+ Args:
+ obj: Object to broadcast, must be serializable
+ src: Source rank for broadcast (default is primary)
+ use_disk: If enabled, removes redundant CPU memory copies by writing to
+ disk
+ """
+ # Either broadcast from primary to the fleet (default),
+ # or use the src setting as the original rank
+ if get_rank() == src:
+ # Emit data
+ buffer = io.BytesIO()
+ torch.save(obj, buffer)
+ data_view = buffer.getbuffer()
+ length_tensor = torch.LongTensor([len(data_view)])
+ length_tensor = broadcast(length_tensor, src=src)
+ data_tensor = torch.ByteTensor(data_view)
+ data_tensor = broadcast(data_tensor, src=src)
+ else:
+ # Fetch from the source
+ length_tensor = torch.LongTensor([0])
+ length_tensor = broadcast(length_tensor, src=src)
+ data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
+ data_tensor = broadcast(data_tensor, src=src)
+ if use_disk:
+ with tempfile.TemporaryFile("r+b") as f:
+ f.write(data_tensor.numpy())
+ # remove reference to the data tensor and hope that Python garbage
+ # collects it
+ del data_tensor
+ f.seek(0)
+ obj = torch.load(f)
+ else:
+ buffer = io.BytesIO(data_tensor.numpy())
+ obj = torch.load(buffer)
+ return obj
+
+
+def all_gather_tensor(tensor: torch.Tensor, world_size=None):
+ if world_size is None:
+ world_size = get_world_size()
+ # make contiguous because NCCL won't gather the tensor otherwise
+ assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
+ tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
+ dist.all_gather(tensor_all, tensor, async_op=False) # performance opt
+ tensor_all = [
+ convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
+ ]
+ return tensor_all
+
+
+def all_gather_batch(tensors: List[torch.Tensor]):
+ """
+ Performs all_gather operation on the provided tensors.
+ """
+ # Queue the gathered tensors
+ world_size = get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+ tensor_list = []
+ output_tensor = []
+ for tensor in tensors:
+ tensor_all = all_gather_tensor(tensor, world_size)
+ tensor_list.append(tensor_all)
+
+ for tensor_all in tensor_list:
+ output_tensor.append(torch.cat(tensor_all, dim=0))
+ return output_tensor
+
+
+class GatherLayer(autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
+ dist.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ dist.all_reduce(all_gradients)
+ return all_gradients[dist.get_rank()]
+
+
+def all_gather_batch_with_grad(tensors):
+ """
+ Performs all_gather operation on the provided tensors.
+ Graph remains connected for backward grad computation.
+ """
+ # Queue the gathered tensors
+ world_size = get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+ tensor_list = []
+ output_tensor = []
+
+ for tensor in tensors:
+ tensor_all = GatherLayer.apply(tensor)
+ tensor_list.append(tensor_all)
+
+ for tensor_all in tensor_list:
+ output_tensor.append(torch.cat(tensor_all, dim=0))
+ return output_tensor
+
+
+def unwrap_ddp_if_wrapped(model):
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ return model.module
+ return model
+
+
+def create_new_process_group(group_size):
+ """
+ Creates process groups of a gives `group_size` and returns
+ process group that current GPU participates in.
+
+ `group_size` must divide the total number of GPUs (world_size).
+
+ Modified from
+ https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
+
+ Args:
+ group_size (int): number of GPU's to collaborate for sync bn
+ """
+
+ assert group_size > 0
+
+ world_size = torch.distributed.get_world_size()
+ if world_size <= 8:
+ if group_size > world_size:
+ logging.warning(
+ f"Requested group size [{group_size}] > world size [{world_size}]. "
+ "Assuming local debug run and capping it to world size."
+ )
+ group_size = world_size
+ assert world_size >= group_size
+ assert world_size % group_size == 0
+
+ group = None
+ for group_num in range(world_size // group_size):
+ group_ids = range(group_num * group_size, (group_num + 1) * group_size)
+ cur_group = torch.distributed.new_group(ranks=group_ids)
+ if torch.distributed.get_rank() // group_size == group_num:
+ group = cur_group
+ # can not drop out and return here, every process must go through creation of all subgroups
+
+ assert group is not None
+ return group
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
diff --git a/phantom/submodules/sam2/training/utils/logger.py b/phantom/submodules/sam2/training/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4b4ef0ebe359063e1ca2c3a46cb8fcc76d067c2
--- /dev/null
+++ b/phantom/submodules/sam2/training/utils/logger.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
+import atexit
+import functools
+import logging
+import sys
+import uuid
+from typing import Any, Dict, Optional, Union
+
+from hydra.utils import instantiate
+
+from iopath.common.file_io import g_pathmgr
+from numpy import ndarray
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+
+from training.utils.train_utils import get_machine_local_and_dist_rank, makedir
+
+Scalar = Union[Tensor, ndarray, int, float]
+
+
+def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
+ makedir(log_dir)
+ summary_writer_method = SummaryWriter
+ return TensorBoardLogger(
+ path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
+ )
+
+
+class TensorBoardWriterWrapper:
+ """
+ A wrapper around a SummaryWriter object.
+ """
+
+ def __init__(
+ self,
+ path: str,
+ *args: Any,
+ filename_suffix: str = None,
+ summary_writer_method: Any = SummaryWriter,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TensorBoard logger.
+ On construction, the logger creates a new events file that logs
+ will be written to. If the environment variable `RANK` is defined,
+ logger will only log if RANK = 0.
+
+ NOTE: If using the logger with distributed training:
+ - This logger can call collective operations
+ - Logs will be written on rank 0 only
+ - Logger must be constructed synchronously *after* initializing distributed process group.
+
+ Args:
+ path (str): path to write logs to
+ *args, **kwargs: Extra arguments to pass to SummaryWriter
+ """
+ self._writer: Optional[SummaryWriter] = None
+ _, self._rank = get_machine_local_and_dist_rank()
+ self._path: str = path
+ if self._rank == 0:
+ logging.info(
+ f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
+ )
+ self._writer = summary_writer_method(
+ log_dir=path,
+ *args,
+ filename_suffix=filename_suffix or str(uuid.uuid4()),
+ **kwargs,
+ )
+ else:
+ logging.debug(
+ f"Not logging meters on this host because env RANK: {self._rank} != 0"
+ )
+ atexit.register(self.close)
+
+ @property
+ def writer(self) -> Optional[SummaryWriter]:
+ return self._writer
+
+ @property
+ def path(self) -> str:
+ return self._path
+
+ def flush(self) -> None:
+ """Writes pending logs to disk."""
+
+ if not self._writer:
+ return
+
+ self._writer.flush()
+
+ def close(self) -> None:
+ """Close writer, flushing pending logs to disk.
+ Logs cannot be written after `close` is called.
+ """
+
+ if not self._writer:
+ return
+
+ self._writer.close()
+ self._writer = None
+
+
+class TensorBoardLogger(TensorBoardWriterWrapper):
+ """
+ A simple logger for TensorBoard.
+ """
+
+ def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
+ """Add multiple scalar values to TensorBoard.
+
+ Args:
+ payload (dict): dictionary of tag name and scalar value
+ step (int, Optional): step value to record
+ """
+ if not self._writer:
+ return
+ for k, v in payload.items():
+ self.log(k, v, step)
+
+ def log(self, name: str, data: Scalar, step: int) -> None:
+ """Add scalar data to TensorBoard.
+
+ Args:
+ name (string): tag name used to group scalars
+ data (float/int/Tensor): scalar data to log
+ step (int, optional): step value to record
+ """
+ if not self._writer:
+ return
+ self._writer.add_scalar(name, data, global_step=step, new_style=True)
+
+ def log_hparams(
+ self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
+ ) -> None:
+ """Add hyperparameter data to TensorBoard.
+
+ Args:
+ hparams (dict): dictionary of hyperparameter names and corresponding values
+ meters (dict): dictionary of name of meter and corersponding values
+ """
+ if not self._writer:
+ return
+ self._writer.add_hparams(hparams, meters)
+
+
+class Logger:
+ """
+ A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger.
+ """
+
+ def __init__(self, logging_conf):
+ # allow turning off TensorBoard with "should_log: false" in config
+ tb_config = logging_conf.tensorboard_writer
+ tb_should_log = tb_config and tb_config.pop("should_log", True)
+ self.tb_logger = instantiate(tb_config) if tb_should_log else None
+
+ def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
+ if self.tb_logger:
+ self.tb_logger.log_dict(payload, step)
+
+ def log(self, name: str, data: Scalar, step: int) -> None:
+ if self.tb_logger:
+ self.tb_logger.log(name, data, step)
+
+ def log_hparams(
+ self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
+ ) -> None:
+ if self.tb_logger:
+ self.tb_logger.log_hparams(hparams, meters)
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ # we tune the buffering value so that the logs are updated
+ # frequently.
+ log_buffer_kb = 10 * 1024 # 10KB
+ io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
+ atexit.register(io.close)
+ return io
+
+
+def setup_logging(
+ name,
+ output_dir=None,
+ rank=0,
+ log_level_primary="INFO",
+ log_level_secondary="ERROR",
+):
+ """
+ Setup various logging streams: stdout and file handlers.
+ For file handlers, we only setup for the master gpu.
+ """
+ # get the filename if we want to log to the file as well
+ log_filename = None
+ if output_dir:
+ makedir(output_dir)
+ if rank == 0:
+ log_filename = f"{output_dir}/log.txt"
+
+ logger = logging.getLogger(name)
+ logger.setLevel(log_level_primary)
+
+ # create formatter
+ FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
+ formatter = logging.Formatter(FORMAT)
+
+ # Cleanup any existing handlers
+ for h in logger.handlers:
+ logger.removeHandler(h)
+ logger.root.handlers = []
+
+ # setup the console handler
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+ if rank == 0:
+ console_handler.setLevel(log_level_primary)
+ else:
+ console_handler.setLevel(log_level_secondary)
+
+ # we log to file as well if user wants
+ if log_filename and rank == 0:
+ file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
+ file_handler.setLevel(log_level_primary)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ logging.root = logger
+
+
+def shutdown_logging():
+ """
+ After training is done, we ensure to shut down all the logger streams.
+ """
+ logging.info("Shutting down loggers...")
+ handlers = logging.root.handlers
+ for handler in handlers:
+ handler.close()
diff --git a/phantom/submodules/sam2/training/utils/train_utils.py b/phantom/submodules/sam2/training/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d5577d5f50c81624737d221dc572ac3c4cee56
--- /dev/null
+++ b/phantom/submodules/sam2/training/utils/train_utils.py
@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+import os
+import random
+import re
+from datetime import timedelta
+from typing import Optional
+
+import hydra
+
+import numpy as np
+import omegaconf
+import torch
+import torch.distributed as dist
+from iopath.common.file_io import g_pathmgr
+from omegaconf import OmegaConf
+
+
+def multiply_all(*args):
+ return np.prod(np.array(args)).item()
+
+
+def collect_dict_keys(config):
+ """This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined"""
+ val_keys = []
+ # If the this config points to the collate function, then it has a key
+ if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
+ val_keys.append(config["dict_key"])
+ else:
+ # Recursively proceed
+ for v in config.values():
+ if isinstance(v, type(config)):
+ val_keys.extend(collect_dict_keys(v))
+ elif isinstance(v, omegaconf.listconfig.ListConfig):
+ for item in v:
+ if isinstance(item, type(config)):
+ val_keys.extend(collect_dict_keys(item))
+ return val_keys
+
+
+class Phase:
+ TRAIN = "train"
+ VAL = "val"
+
+
+def register_omegaconf_resolvers():
+ OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
+ OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
+ OmegaConf.register_new_resolver("add", lambda x, y: x + y)
+ OmegaConf.register_new_resolver("times", multiply_all)
+ OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
+ OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
+ OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
+ OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
+ OmegaConf.register_new_resolver("int", lambda x: int(x))
+ OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
+ OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
+
+
+def setup_distributed_backend(backend, timeout_mins):
+ """
+ Initialize torch.distributed and set the CUDA device.
+ Expects environment variables to be set as per
+ https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
+ along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.
+ """
+ # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
+ # of waiting
+ os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
+ logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
+ dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
+ return dist.get_rank()
+
+
+def get_machine_local_and_dist_rank():
+ """
+ Get the distributed and local rank of the current gpu.
+ """
+ local_rank = int(os.environ.get("LOCAL_RANK", None))
+ distributed_rank = int(os.environ.get("RANK", None))
+ assert (
+ local_rank is not None and distributed_rank is not None
+ ), "Please the set the RANK and LOCAL_RANK environment variables."
+ return local_rank, distributed_rank
+
+
+def print_cfg(cfg):
+ """
+ Supports printing both Hydra DictConfig and also the AttrDict config
+ """
+ logging.info("Training with config:")
+ logging.info(OmegaConf.to_yaml(cfg))
+
+
+def set_seeds(seed_value, max_epochs, dist_rank):
+ """
+ Set the python random, numpy and torch seed for each gpu. Also set the CUDA
+ seeds if the CUDA is available. This ensures deterministic nature of the training.
+ """
+ # Since in the pytorch sampler, we increment the seed by 1 for every epoch.
+ seed_value = (seed_value + dist_rank) * max_epochs
+ logging.info(f"MACHINE SEED: {seed_value}")
+ random.seed(seed_value)
+ np.random.seed(seed_value)
+ torch.manual_seed(seed_value)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed_value)
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ logging.info(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_amp_type(amp_type: Optional[str] = None):
+ if amp_type is None:
+ return None
+ assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
+ if amp_type == "bfloat16":
+ return torch.bfloat16
+ else:
+ return torch.float16
+
+
+def log_env_variables():
+ env_keys = sorted(list(os.environ.keys()))
+ st = ""
+ for k in env_keys:
+ v = os.environ[k]
+ st += f"{k}={v}\n"
+ logging.info("Logging ENV_VARIABLES")
+ logging.info(st)
+
+
+class AverageMeter:
+ """Computes and stores the average and current value"""
+
+ def __init__(self, name, device, fmt=":f"):
+ self.name = name
+ self.fmt = fmt
+ self.device = device
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+ self._allow_updates = True
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
+ return fmtstr.format(**self.__dict__)
+
+
+class MemMeter:
+ """Computes and stores the current, avg, and max of peak Mem usage per iteration"""
+
+ def __init__(self, name, device, fmt=":f"):
+ self.name = name
+ self.fmt = fmt
+ self.device = device
+ self.reset()
+
+ def reset(self):
+ self.val = 0 # Per iteration max usage
+ self.avg = 0 # Avg per iteration max usage
+ self.peak = 0 # Peak usage for lifetime of program
+ self.sum = 0
+ self.count = 0
+ self._allow_updates = True
+
+ def update(self, n=1, reset_peak_usage=True):
+ self.val = torch.cuda.max_memory_allocated() // 1e9
+ self.sum += self.val * n
+ self.count += n
+ self.avg = self.sum / self.count
+ self.peak = max(self.peak, self.val)
+ if reset_peak_usage:
+ torch.cuda.reset_peak_memory_stats()
+
+ def __str__(self):
+ fmtstr = (
+ "{name}: {val"
+ + self.fmt
+ + "} ({avg"
+ + self.fmt
+ + "}/{peak"
+ + self.fmt
+ + "})"
+ )
+ return fmtstr.format(**self.__dict__)
+
+
+def human_readable_time(time_seconds):
+ time = int(time_seconds)
+ minutes, seconds = divmod(time, 60)
+ hours, minutes = divmod(minutes, 60)
+ days, hours = divmod(hours, 24)
+ return f"{days:02}d {hours:02}h {minutes:02}m"
+
+
+class DurationMeter:
+ def __init__(self, name, device, fmt=":f"):
+ self.name = name
+ self.device = device
+ self.fmt = fmt
+ self.val = 0
+
+ def reset(self):
+ self.val = 0
+
+ def update(self, val):
+ self.val = val
+
+ def add(self, val):
+ self.val += val
+
+ def __str__(self):
+ return f"{self.name}: {human_readable_time(self.val)}"
+
+
+class ProgressMeter:
+ def __init__(self, num_batches, meters, real_meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.real_meters = real_meters
+ self.prefix = prefix
+
+ def display(self, batch, enable_print=False):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(meter) for meter in self.meters]
+ entries += [
+ " | ".join(
+ [
+ f"{os.path.join(name, subname)}: {val:.4f}"
+ for subname, val in meter.compute().items()
+ ]
+ )
+ for name, meter in self.real_meters.items()
+ ]
+ logging.info(" | ".join(entries))
+ if enable_print:
+ print(" | ".join(entries))
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = "{:" + str(num_digits) + "d}"
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
+
+
+def get_resume_checkpoint(checkpoint_save_dir):
+ if not g_pathmgr.isdir(checkpoint_save_dir):
+ return None
+ ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
+ if not g_pathmgr.isfile(ckpt_file):
+ return None
+
+ return ckpt_file
diff --git a/requirements.txt b/requirements.txt
index 97ff04d2ffb49bebfbd8a0f1ae3f7e9763110c51..4a39351a69e39e7bd88910dac47bbd7334b9ec9d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,31 +1,32 @@
-# Gradio和Spaces
-gradio==4.44.0
-spaces==0.28.3
+# Gradio 和 Spaces
+gradio>=4.44.0
+spaces>=0.28.3
-# PyTorch (HF已预装,但指定版本)
+# PyTorch (使用 CUDA 12.1 版本)
+--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.1.0
torchvision==0.16.0
# 基础科学计算
numpy==1.26.4
opencv-python==4.8.1.78
-pillow==10.1.0
-scipy==1.11.4
-scikit-learn==1.3.2
+pillow>=10.1.0
+scipy>=1.11.4
+scikit-learn>=1.3.2
# 配置管理
hydra-core==1.3.2
omegaconf==2.3.0
# 视频处理
-mediapy==1.1.9
+mediapy>=1.1.9
-# 3D处理
-open3d==0.18.0
+# 3D 处理
+open3d>=0.18.0
# 机器学习
transformers==4.42.4
-timm==0.9.12
+timm>=0.9.12
# 其他工具
joblib
@@ -35,5 +36,5 @@ Rtree
protobuf==3.20.0
gdown
-# MMCV (Phantom需要)
+# MMCV (将在 setup.sh 中安装 full 版本)
mmcv==1.3.9
diff --git a/setup.sh b/setup.sh
index 19f79b00c8c592dff77f2b0bd272b0cfe71336f7..00965cf30c2ac9e1b8ecde5c64dd8fd0619ada0f 100644
--- a/setup.sh
+++ b/setup.sh
@@ -1,127 +1,156 @@
#!/bin/bash
-# Phantom环境配置脚本
-# 在app.py启动时运行(仅首次)
+# Phantom HuggingFace Spaces 安装脚本
+# 仅 Inference 模式 - 跳过 training 相关依赖
set -e
PHANTOM_DIR="/home/user/app/phantom"
LOG_FILE="/tmp/phantom_setup.log"
-# 日志函数
log() {
echo "[$(date +'%H:%M:%S')] $1" | tee -a "$LOG_FILE"
}
-log "🚀 开始配置Phantom环境"
+log "🚀 开始配置 Phantom 环境 (Inference Only)"
-# 检查phantom目录
+# 检查 phantom 目录
if [ ! -d "$PHANTOM_DIR" ]; then
- log "❌ Phantom目录不存在"
+ log "❌ Phantom 目录不存在"
exit 1
fi
cd "$PHANTOM_DIR"
-# ========== 安装子模块 ==========
+# ========== 安装 Inference 必需依赖 ==========
-# 1. SAM2
+# 1. 安装 PyTorch (如果尚未安装)
+if ! python -c "import torch" 2>/dev/null; then
+ log "📦 安装 PyTorch..."
+ pip install -q torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
+fi
+
+# 2. SAM2 (分割模型)
if [ ! -f "/tmp/.sam2_installed" ]; then
- log "📦 安装SAM2..."
- cd submodules/sam2
- pip install -q -e . 2>&1 | tee -a "$LOG_FILE" || log "⚠️ SAM2安装出现警告"
+ log "📦 安装 SAM2..."
+ cd "$PHANTOM_DIR/submodules/sam2"
+ pip install -q -e . 2>&1 | tee -a "$LOG_FILE" || log "⚠️ SAM2 警告"
touch /tmp/.sam2_installed
- log "✅ SAM2完成"
+ log "✅ SAM2 完成"
fi
-# 2. HaMeR
+# 3. HaMeR (手部姿态估计)
if [ ! -f "/tmp/.hamer_installed" ]; then
- log "📦 安装HaMeR..."
+ log "📦 安装 HaMeR..."
cd "$PHANTOM_DIR/submodules/phantom-hamer"
- pip install -q -e .[all] 2>&1 | tee -a "$LOG_FILE" || log "⚠️ HaMeR安装出现警告"
-
- # 安装ViTPose
+ pip install -q -e .[all] 2>&1 | tee -a "$LOG_FILE" || log "⚠️ HaMeR 警告"
+
+ # 安装 ViTPose
if [ -d "third-party/ViTPose" ]; then
- pip install -q -e third-party/ViTPose 2>&1 | tee -a "$LOG_FILE"
+ log "📦 安装 ViTPose..."
+ pip install -q -e third-party/ViTPose 2>&1 | tee -a "$LOG_FILE" || true
fi
-
- # 下载demo数据(如果不存在)
- if [ ! -d "_DATA/hamer_demo_data" ]; then
- log "📥 下载HaMeR demo数据..."
- cd _DATA
- wget -q https://www.cs.utexas.edu/~pavlakos/hamer/data/hamer_demo_data.tar.gz
- tar -xzf hamer_demo_data.tar.gz 2>&1 | tee -a "$LOG_FILE"
- rm hamer_demo_data.tar.gz
- fi
-
- touch /tmp/.hamer_installed
- log "✅ HaMeR完成"
-fi
-# 3. MMCV-Full
-if [ ! -f "/tmp/.mmcv_installed" ]; then
- log "📦 安装MMCV-Full..."
- pip install -q mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html 2>&1 | tee -a "$LOG_FILE"
- touch /tmp/.mmcv_installed
- log "✅ MMCV完成"
+ touch /tmp/.hamer_installed
+ log "✅ HaMeR 完成"
fi
-# 4. Robosuite
-if [ ! -f "/tmp/.robosuite_installed" ]; then
- log "📦 安装Robosuite..."
- cd "$PHANTOM_DIR/submodules/phantom-robosuite"
- pip install -q -e . 2>&1 | tee -a "$LOG_FILE"
- touch /tmp/.robosuite_installed
- log "✅ Robosuite完成"
+# 4. 下载 HaMeR demo 数据
+if [ ! -d "$PHANTOM_DIR/submodules/phantom-hamer/_DATA/hamer_demo_data" ]; then
+ log "📥 下载 HaMeR demo 数据..."
+ cd "$PHANTOM_DIR/submodules/phantom-hamer"
+ mkdir -p _DATA && cd _DATA
+ if [ ! -f "hamer_demo_data.tar.gz" ]; then
+ wget -q https://www.cs.utexas.edu/~pavlakos/hamer/data/hamer_demo_data.tar.gz || log "⚠️ HaMeR 数据下载失败"
+ fi
+ if [ -f "hamer_demo_data.tar.gz" ]; then
+ tar --warning=no-unknown-keyword -xzf hamer_demo_data.tar.gz 2>&1 | tee -a "$LOG_FILE" || true
+ rm -f hamer_demo_data.tar.gz
+ log "✅ HaMeR 数据完成"
+ fi
fi
-# 5. Robomimic
-if [ ! -f "/tmp/.robomimic_installed" ]; then
- log "📦 安装Robomimic..."
- cd "$PHANTOM_DIR/submodules/phantom-robomimic"
- pip install -q -e . 2>&1 | tee -a "$LOG_FILE"
- touch /tmp/.robomimic_installed
- log "✅ Robomimic完成"
+# 5. MMCV (仅基础版本,inference 够用)
+if [ ! -f "/tmp/.mmcv_installed" ]; then
+ log "📦 安装 MMCV..."
+ pip install -q mmcv==1.3.9 2>&1 | tee -a "$LOG_FILE" || true
+ # 尝试安装 mmcv-full,失败也没关系
+ pip install -q mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html 2>&1 | tee -a "$LOG_FILE" || log "⚠️ MMCV-full 跳过,使用基础版本"
+ touch /tmp/.mmcv_installed
+ log "✅ MMCV 完成"
fi
-# 6. E2FGVI权重
-if [ ! -f "/tmp/.e2fgvi_weights" ]; then
- log "📥 下载E2FGVI权重..."
- cd "$PHANTOM_DIR/submodules/phantom-E2FGVI/E2FGVI/release_model"
- if [ ! -f "E2FGVI-HQ.pth" ]; then
- gdown --fuzzy https://drive.google.com/file/d/10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3/view?usp=sharing 2>&1 | tee -a "$LOG_FILE"
- fi
- touch /tmp/.e2fgvi_weights
- log "✅ E2FGVI权重完成"
+# 6. E2FGVI (视频修复)
+E2FGVI_DIR="$PHANTOM_DIR/submodules/phantom-E2FGVI/E2FGVI/release_model"
+if [ ! -f "$E2FGVI_DIR/E2FGVI-HQ.pth" ]; then
+ log "📥 下载 E2FGVI 权重..."
+ mkdir -p "$E2FGVI_DIR"
+ cd "$E2FGVI_DIR"
+ pip install -q gdown
+ gdown --fuzzy "https://drive.google.com/file/d/10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3/view?usp=sharing" 2>&1 | tee -a "$LOG_FILE" || log "⚠️ E2FGVI 权重下载失败"
+ log "✅ E2FGVI 权重完成"
fi
-# 7. E2FGVI
if [ ! -f "/tmp/.e2fgvi_installed" ]; then
- log "📦 安装E2FGVI..."
+ log "📦 安装 E2FGVI..."
cd "$PHANTOM_DIR/submodules/phantom-E2FGVI"
pip install -q -e . 2>&1 | tee -a "$LOG_FILE"
touch /tmp/.e2fgvi_installed
- log "✅ E2FGVI完成"
+ log "✅ E2FGVI 完成"
fi
-# 8. Phantom主包
-if [ ! -f "/tmp/.phantom_installed" ]; then
- log "📦 安装Phantom主包..."
+# ========== 跳过 Training 依赖 ==========
+# 以下包仅用于训练,inference 不需要:
+# - phantom-robosuite (机器人仿真)
+# - phantom-robomimic (机器人学习)
+log "⏭️ 跳过 Training 依赖 (robosuite, robomimic)"
+
+# 7. 其他 inference 依赖
+log "📦 安装其他依赖..."
+pip install -q joblib mediapy 2>&1 | tee -a "$LOG_FILE" || true
+pip install -q transformers==4.42.4 2>&1 | tee -a "$LOG_FILE" || true
+pip install -q PyOpenGL==3.1.4 Rtree protobuf==3.20.0 2>&1 | tee -a "$LOG_FILE" || true
+pip install -q hydra-core==1.3.2 omegaconf==2.3.0 2>&1 | tee -a "$LOG_FILE" || true
+pip install -q numpy==1.26.4 2>&1 | tee -a "$LOG_FILE" || true
+# open3d 体积大,尝试安装但不强求
+pip install -q open3d 2>&1 | tee -a "$LOG_FILE" || log "⚠️ open3d 跳过"
+
+# 8. Phantom 主包
+if [ ! -f "/tmp/.phantom_pkg_installed" ]; then
+ log "📦 安装 Phantom 主包..."
cd "$PHANTOM_DIR"
pip install -q -e . 2>&1 | tee -a "$LOG_FILE"
- touch /tmp/.phantom_installed
- log "✅ Phantom主包完成"
+ touch /tmp/.phantom_pkg_installed
+ log "✅ Phantom 主包完成"
fi
-# 9. 验证MANO模型
+# 9. 下载示例数据(可选)
+SAMPLE_DATA_DIR="$PHANTOM_DIR/data/raw"
+if [ ! -d "$SAMPLE_DATA_DIR/pick_and_place" ]; then
+ log "📥 下载示例数据..."
+ mkdir -p "$SAMPLE_DATA_DIR"
+ cd "$SAMPLE_DATA_DIR"
+ wget -q https://download.cs.stanford.edu/juno/phantom/pick_and_place.zip || log "⚠️ 示例数据下载失败"
+ if [ -f "pick_and_place.zip" ]; then
+ unzip -q pick_and_place.zip
+ rm -f pick_and_place.zip
+ log "✅ 示例数据完成"
+ fi
+fi
+
+# 10. 检查 MANO 模型
MANO_DIR="$PHANTOM_DIR/submodules/phantom-hamer/_DATA/data/mano"
+mkdir -p "$MANO_DIR"
+
+# 检查是否已存在(可能用户已经放在仓库里了)
if [ -f "$MANO_DIR/MANO_LEFT.pkl" ] && [ -f "$MANO_DIR/MANO_RIGHT.pkl" ]; then
- log "✅ MANO模型已就绪"
+ log "✅ MANO 模型已就绪"
else
- log "⚠️ MANO模型缺失,请上传"
+ log "⚠️ MANO 模型缺失!"
+ log " 请将文件放到: $MANO_DIR"
fi
-log "🎉 Phantom环境配置完成"
-log "日志文件: $LOG_FILE"
-
# 标记完成
touch /tmp/.phantom_ready
+
+log "🎉 Phantom 环境配置完成 (Inference Only)"
+log "📝 日志文件: $LOG_FILE"
diff --git a/your b/your
new file mode 100644
index 0000000000000000000000000000000000000000..35241f92ccf16369c0fbb565142156b2ff500732
--- /dev/null
+++ b/your
@@ -0,0 +1 @@
+export PATH="$HOME/.local/bin:$PATH" shell config file