lsnu committed on
Commit
f69c0bb
·
verified ·
1 Parent(s): 8944de3

Add files using upload-large-folder tool

Browse files
code/reveal_vla_bimanual/eval/ablations.py CHANGED
@@ -1,8 +1,7 @@
1
  MANDATORY_ABLATIONS: tuple[str, ...] = (
2
- "no_reveal_state_head",
3
  "no_world_model",
4
- "no_planner_reranking",
5
- "no_support_mode_conditioning",
6
- "no_wrist_cameras",
7
- "no_global_camera",
8
  )
 
1
# Ablation names that every evaluation sweep must cover. These strings are
# matched against the `ablation` argument in the eval runners (e.g. the
# `no_interaction_head` / `no_planner` / `no_role_tokens` / `short_history`
# branches in run_reveal_benchmark.select_chunk).
MANDATORY_ABLATIONS: tuple[str, ...] = (
    "no_interaction_head",
    "no_world_model",
    "no_planner",
    "no_role_tokens",
    "short_history",
)
code/reveal_vla_bimanual/eval/metrics.py CHANGED
@@ -16,6 +16,14 @@ class BenchmarkMetrics:
16
  disturbance_cost: float | None = None
17
 
18
 
 
 
 
 
 
 
 
 
19
  def mean_success(per_task_success: dict[str, float]) -> float:
20
  if not per_task_success:
21
  return 0.0
@@ -50,3 +58,50 @@ def mean_disturbance_cost(values: np.ndarray) -> float:
50
  if values.size == 0:
51
  return 0.0
52
  return float(values.mean())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  disturbance_cost: float | None = None
17
 
18
 
19
@dataclass
class PlannerDiagnostics:
    """Container for planner-quality metrics produced by the proxy diagnostics.

    Field names mirror the metric functions defined in this module
    (planner_top1_accuracy, planner_regret, risk_calibration_mse,
    role_collapse_rate); all values are plain floats.
    """

    # Fraction of samples where the planner's top-scoring candidate matches
    # the oracle-utility argmax.
    top1_accuracy: float
    # Mean utility gap between the oracle-best candidate and the selected one.
    regret: float
    # Mean squared error between predicted and realized risk.
    risk_calibration_mse: float
    # Fraction of samples whose arms behave (near-)identically — see
    # role_collapse_rate below.
    role_collapse_rate: float
25
+
26
+
27
  def mean_success(per_task_success: dict[str, float]) -> float:
28
  if not per_task_success:
29
  return 0.0
 
58
  if values.size == 0:
59
  return 0.0
60
  return float(values.mean())
61
+
62
+
63
def planner_top1_accuracy(pred_scores: np.ndarray, oracle_utility: np.ndarray) -> float:
    """Return the fraction of rows whose predicted-best candidate matches the oracle-best.

    Args:
        pred_scores: Planner scores per candidate, shape (..., num_candidates).
        oracle_utility: Ground-truth utility per candidate, same shape.

    Returns:
        Mean agreement between ``argmax`` of the two arrays along the last
        axis, or 0.0 when either input is empty.
    """
    pred_scores = np.asarray(pred_scores)
    oracle_utility = np.asarray(oracle_utility)
    # Guard BOTH inputs: argmax over an empty last axis would raise, and the
    # original only checked pred_scores.
    if pred_scores.size == 0 or oracle_utility.size == 0:
        return 0.0
    return float((pred_scores.argmax(axis=-1) == oracle_utility.argmax(axis=-1)).mean())
69
+
70
+
71
def planner_regret(selected_indices: np.ndarray, oracle_utility: np.ndarray) -> float:
    """Return the mean utility gap between the oracle-best and the selected candidate.

    Args:
        selected_indices: Index of the chosen candidate per sample, shape (batch,).
        oracle_utility: Utility of each candidate, shape (batch, num_candidates).

    Returns:
        Mean of ``max(utility) - utility[selected]`` per sample; 0.0 when
        either input is empty.
    """
    selected_indices = np.asarray(selected_indices, dtype=np.int64)
    oracle_utility = np.asarray(oracle_utility, dtype=np.float32)
    # Guard BOTH inputs: an empty index array against a non-empty utility
    # matrix would produce a broadcasting error below.
    if oracle_utility.size == 0 or selected_indices.size == 0:
        return 0.0
    batch_index = np.arange(selected_indices.shape[0])
    selected = oracle_utility[batch_index, selected_indices]
    oracle = oracle_utility.max(axis=-1)
    return float((oracle - selected).mean())
80
+
81
+
82
def risk_calibration_mse(predicted_risk: np.ndarray, realized_risk: np.ndarray) -> float:
    """Return the mean squared error between predicted and realized risk values.

    Both inputs are coerced to float32; returns 0.0 for empty predictions.
    """
    predicted_risk = np.asarray(predicted_risk, dtype=np.float32)
    realized_risk = np.asarray(realized_risk, dtype=np.float32)
    if not predicted_risk.size:
        return 0.0
    error = predicted_risk - realized_risk
    return float((error ** 2).mean())
88
+
89
+
90
+ def role_collapse_rate(
91
+ action_chunks: np.ndarray,
92
+ arm_role_logits: np.ndarray | None = None,
93
+ action_threshold: float = 1e-2,
94
+ role_threshold: float = 0.1,
95
+ ) -> float:
96
+ action_chunks = np.asarray(action_chunks, dtype=np.float32)
97
+ right_actions = action_chunks[..., :7]
98
+ left_actions = action_chunks[..., 7:]
99
+ action_gap = np.mean(np.abs(right_actions - left_actions), axis=(-1, -2))
100
+ collapsed = action_gap <= action_threshold
101
+ if arm_role_logits is not None:
102
+ arm_role_logits = np.asarray(arm_role_logits, dtype=np.float32)
103
+ role_probs = np.exp(arm_role_logits - arm_role_logits.max(axis=-1, keepdims=True))
104
+ role_probs = role_probs / np.clip(role_probs.sum(axis=-1, keepdims=True), 1e-6, None)
105
+ role_gap = np.mean(np.abs(role_probs[..., 0, :] - role_probs[..., 1, :]), axis=-1)
106
+ collapsed = np.logical_or(collapsed, role_gap <= role_threshold)
107
+ return float(collapsed.mean())
code/reveal_vla_bimanual/eval/run_proxy_diagnostics.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import Tensor
11
+ from torch.utils.data import DataLoader
12
+
13
+ from eval.metrics import planner_regret, planner_top1_accuracy, risk_calibration_mse, role_collapse_rate
14
+ from eval.run_reveal_benchmark import load_model
15
+ from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset
16
+
17
+
18
+ def _move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]:
19
+ moved = {}
20
+ for key, value in batch.items():
21
+ if isinstance(value, Tensor):
22
+ moved[key] = value.to(device)
23
+ else:
24
+ moved[key] = value
25
+ return moved
26
+
27
+
28
def main() -> None:
    """Run offline planner proxy diagnostics over a teacher dataset.

    Loads a checkpoint, runs the policy with planning enabled over every
    batch of the dataset, accumulates planner scores / selections / risk
    predictions, and writes the aggregated metrics to
    ``<output-dir>/proxy_diagnostics.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--output-dir", required=True)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _ = load_model(args.checkpoint, device=device)
    bundle = load_teacher_dataset(args.dataset)
    # NOTE(review): assumes the bundle carries its own "resolution" entry —
    # verify against load_teacher_dataset.
    dataset = dataset_from_bundle(bundle, resolution=int(bundle["resolution"]))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # deterministic order so diagnostics are reproducible
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # Per-batch accumulators, concatenated after the loop.
    score_batches: list[np.ndarray] = []
    utility_batches: list[np.ndarray] = []
    best_index_batches: list[np.ndarray] = []
    risk_batches: list[np.ndarray] = []
    realized_risk_batches: list[np.ndarray] = []
    collapse_batches: list[float] = []

    with torch.no_grad():
        for batch in loader:
            moved = _move_batch_to_device(batch, device)
            # Force planning and evaluate the planner on the dataset's own
            # candidate chunks rather than model-sampled ones.
            outputs = model(
                images=moved["images"],
                proprio=moved["proprio"],
                texts=moved["texts"],
                history_images=moved.get("history_images"),
                history_proprio=moved.get("history_proprio"),
                plan=True,
                candidate_chunks_override=moved["candidate_action_chunks"],
            )
            if "planner_scores" not in outputs:
                raise RuntimeError("Planner outputs were not produced for proxy diagnostics.")
            score_batches.append(outputs["planner_scores"].detach().cpu().numpy())
            utility_batches.append(moved["candidate_utility"].detach().cpu().numpy())
            best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy())
            risk_batches.append(outputs["planner_risk_values"].detach().cpu().numpy())
            # Realized risk = disturbance cost + re-occlusion rate, clamped
            # to [0, 1] to match the predicted-risk range.
            realized_risk_batches.append(
                torch.clamp(
                    moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"],
                    0.0,
                    1.0,
                )
                .detach()
                .cpu()
                .numpy()
            )
            # role_collapse_rate expects a chunk axis; insert one ([:, None])
            # since we evaluate a single selected chunk per sample.
            selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None]
            role_logits = None
            if outputs.get("interaction_state") is not None:
                role_logits = outputs["interaction_state"]["arm_role_logits"].detach().cpu().numpy()[:, None]
            collapse_batches.append(role_collapse_rate(selected_chunk, role_logits))

    # Fall back to empty arrays so the metric functions' empty-input guards
    # kick in when the loader yields nothing.
    scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32)
    utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32)
    selected_indices = (
        np.concatenate(best_index_batches, axis=0) if best_index_batches else np.zeros((0,), dtype=np.int64)
    )
    predicted_risk = np.concatenate(risk_batches, axis=0) if risk_batches else np.zeros((0, 0), dtype=np.float32)
    realized_risk = (
        np.concatenate(realized_risk_batches, axis=0) if realized_risk_batches else np.zeros((0, 0), dtype=np.float32)
    )

    diagnostics = {
        "planner_top1_accuracy": planner_top1_accuracy(scores, utility),
        "planner_regret": planner_regret(selected_indices, utility),
        "risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk),
        "role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0,
        "num_samples": int(scores.shape[0]),
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "proxy_diagnostics.json").write_text(json.dumps(diagnostics, indent=2), encoding="utf-8")
    print(json.dumps(diagnostics, indent=2))


if __name__ == "__main__":
    main()
code/reveal_vla_bimanual/eval/run_reveal_benchmark.py CHANGED
@@ -49,7 +49,7 @@ def _trainer_config_from_dict(cfg: dict[str, Any]) -> TrainerConfig:
49
 
50
 
51
  def load_model(checkpoint_path: str | Path, device: torch.device) -> tuple[torch.nn.Module, dict[str, Any]]:
52
- checkpoint = torch.load(Path(checkpoint_path), map_location="cpu")
53
  policy_config = _policy_config_from_dict(checkpoint["policy_config"])
54
  trainer_config = _trainer_config_from_dict(checkpoint["trainer_config"])
55
  model = build_policy(policy_config, trainer_config).to(device)
@@ -112,6 +112,22 @@ def select_chunk(
112
  "proprio": batch["proprio"],
113
  "texts": batch["texts"],
114
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  if hasattr(model, "reveal_head"):
116
  if ablation == "no_world_model":
117
  outputs = model(**forward_kwargs, plan=False)
@@ -181,10 +197,13 @@ def evaluate_model(
181
  episode_visibility.append(float(privileged_state["visibility"]))
182
  episode_corridor.append(float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any()))
183
  episode_disturbance.append(float(privileged_state["disturbance_cost"]))
184
- if "reveal_state" in outputs and ablation != "no_reveal_state_head":
 
 
 
185
  persistence_errors.append(
186
  persistence_horizon_mae(
187
- outputs["reveal_state"]["persistence_horizon"][0].detach().cpu().numpy(),
188
  privileged_state["persistence_horizon"],
189
  )
190
  )
 
49
 
50
 
51
  def load_model(checkpoint_path: str | Path, device: torch.device) -> tuple[torch.nn.Module, dict[str, Any]]:
52
+ checkpoint = torch.load(Path(checkpoint_path), map_location="cpu", weights_only=False)
53
  policy_config = _policy_config_from_dict(checkpoint["policy_config"])
54
  trainer_config = _trainer_config_from_dict(checkpoint["trainer_config"])
55
  model = build_policy(policy_config, trainer_config).to(device)
 
112
  "proprio": batch["proprio"],
113
  "texts": batch["texts"],
114
  }
115
+ if hasattr(model, "interaction_head"):
116
+ outputs = model(
117
+ **forward_kwargs,
118
+ plan=(ablation not in {"no_world_model", "no_interaction_head"}),
119
+ support_mode_conditioning=True,
120
+ use_interaction_head=(ablation != "no_interaction_head"),
121
+ use_role_tokens=(ablation != "no_role_tokens"),
122
+ history_steps_override=(2 if ablation == "short_history" else None),
123
+ )
124
+ if ablation == "no_planner":
125
+ if "candidate_chunks" in outputs:
126
+ return outputs["candidate_chunks"][:, 0], outputs
127
+ return outputs["action_mean"], outputs
128
+ if "planned_chunk" in outputs and ablation not in {"no_world_model", "no_interaction_head"}:
129
+ return outputs["planned_chunk"], outputs
130
+ return outputs["action_mean"], outputs
131
  if hasattr(model, "reveal_head"):
132
  if ablation == "no_world_model":
133
  outputs = model(**forward_kwargs, plan=False)
 
197
  episode_visibility.append(float(privileged_state["visibility"]))
198
  episode_corridor.append(float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any()))
199
  episode_disturbance.append(float(privileged_state["disturbance_cost"]))
200
+ state_output = outputs.get("interaction_state")
201
+ if state_output is None:
202
+ state_output = outputs.get("reveal_state")
203
+ if state_output is not None and ablation != "no_interaction_head":
204
  persistence_errors.append(
205
  persistence_horizon_mae(
206
+ state_output["persistence_horizon"][0].detach().cpu().numpy(),
207
  privileged_state["persistence_horizon"],
208
  )
209
  )
code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py CHANGED
@@ -7,13 +7,12 @@ from typing import Any, Sequence
7
 
8
  import numpy as np
9
  import torch
10
- from helpers.observation_utils import create_obs_config
11
- from omegaconf import OmegaConf
12
- from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper
13
- from rlbench.action_modes.arm_action_modes import BimanualEndEffectorPoseViaPlanning
14
- from rlbench.action_modes.gripper_action_modes import BimanualDiscrete
15
  from rlbench.backend.utils import task_file_to_task_class
16
- from rlbench.environment import Environment
17
 
18
  from models.action_decoder import ChunkDecoderConfig
19
  from models.backbones import FrozenVLBackboneConfig
@@ -23,8 +22,13 @@ from models.planner import PlannerConfig
23
  from models.policy import PolicyConfig
24
  from models.reveal_head import RevealHeadConfig
25
  from models.world_model import RevealWMConfig
26
- from sim_rlbench.dataset import absolute_action_from_delta, bimanual_proprio_from_obs, stack_live_rgb_obs
27
- from train.trainer import TrainerConfig, build_policy
 
 
 
 
 
28
 
29
 
30
  def _policy_config_from_checkpoint(checkpoint: dict[str, Any]) -> PolicyConfig:
@@ -48,6 +52,19 @@ def _episode_language_goal(descriptions: Sequence[str]) -> str:
48
  return str(descriptions[0]) if descriptions else ""
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def main() -> None:
52
  parser = argparse.ArgumentParser()
53
  parser.add_argument("--checkpoint", required=True)
@@ -71,55 +88,57 @@ def main() -> None:
71
  model.load_state_dict(checkpoint["state_dict"], strict=True)
72
  model.eval()
73
  plan_requested = bool(args.plan)
74
- plan_applied = plan_requested and trainer_config.policy_type == "reveal_state"
75
  planning_note = None
76
- if plan_requested and trainer_config.policy_type != "reveal_state":
77
  plan_applied = False
78
  planning_note = "Planner requested for a backbone-only checkpoint; evaluating the backbone policy only."
79
- elif plan_requested and trainer_config.policy_type == "reveal_state" and not args.allow_unsupervised_planning:
80
- plan_applied = False
81
- planning_note = (
82
- "RLBench batches do not provide reveal supervision. Unsupervised reveal planning was disabled; "
83
- "use --allow-unsupervised-planning to override."
84
- )
85
-
86
- obs_config = create_obs_config(
87
- ["front", "wrist_left", "wrist_right"],
88
- [args.resolution, args.resolution],
89
- "BIMANUAL_PERACT",
90
- "bimanual",
91
- )
92
- action_mode = BimanualMoveArmThenGripper(
93
- BimanualEndEffectorPoseViaPlanning(absolute_mode=True, frame="world", collision_checking=False),
94
- BimanualDiscrete(),
95
- )
96
- env = Environment(
97
- action_mode=action_mode,
98
- obs_config=obs_config,
99
- headless=args.headless,
100
- robot_setup="dual_panda",
101
- )
102
 
103
  results: dict[str, Any] = {
104
  "checkpoint": str(Path(args.checkpoint).resolve()),
105
  "plan_requested": plan_requested,
106
  "plan_applied": plan_applied,
 
107
  "support_mode_conditioning": not args.disable_support_mode_conditioning,
108
  "episodes_per_task": args.episodes_per_task,
109
  "episode_length": args.episode_length,
110
  "resolution": args.resolution,
 
111
  "tasks": {},
112
  }
113
  if planning_note is not None:
114
  results["planning_note"] = planning_note
115
 
116
- try:
117
- env.launch()
118
- for task_name in args.tasks:
119
- task_class = task_file_to_task_class(task_name, bimanual=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  task = env.get_task(task_class)
121
- task_successes: list[float] = []
122
- task_returns: list[float] = []
123
  for _ in range(args.episodes_per_task):
124
  descriptions, obs = task.reset()
125
  language_goal = _episode_language_goal(descriptions)
@@ -157,7 +176,7 @@ def main() -> None:
157
  dtype=proprio.dtype,
158
  )
159
  with torch.no_grad():
160
- if trainer_config.policy_type == "reveal_state":
161
  outputs = model(
162
  images=images,
163
  proprio=proprio,
@@ -181,12 +200,12 @@ def main() -> None:
181
  step_action = chosen_chunk[0, 0].detach().float().cpu().numpy()
182
  if history_steps > 0:
183
  if len(history_images) >= history_steps:
184
- history_images = history_images[-history_steps + 1 :]
185
- history_proprio = history_proprio[-history_steps + 1 :]
 
186
  history_images.append(images[0].detach().cpu().numpy())
187
  history_proprio.append(proprio[0].detach().cpu().numpy())
188
- env_action = absolute_action_from_delta(obs, step_action, ignore_collisions=True)
189
- obs, reward, done = task.step(env_action)
190
  total_reward += float(reward)
191
  if reward >= 1.0:
192
  success = 1.0
@@ -195,13 +214,17 @@ def main() -> None:
195
  task_successes.append(success)
196
  task_returns.append(total_reward)
197
  results["tasks"][task_name] = {
 
198
  "successes": task_successes,
199
  "returns": task_returns,
200
  "mean_success": float(np.mean(task_successes)) if task_successes else 0.0,
201
  "mean_return": float(np.mean(task_returns)) if task_returns else 0.0,
202
  }
203
- finally:
204
- env.shutdown()
 
 
 
205
 
206
  task_scores = [task_data["mean_success"] for task_data in results["tasks"].values()]
207
  results["mean_success"] = float(np.mean(task_scores)) if task_scores else 0.0
@@ -222,7 +245,10 @@ def main() -> None:
222
  "",
223
  ]
224
  for task_name, task_data in results["tasks"].items():
225
- lines.append(f"- `{task_name}`: mean_success={task_data['mean_success']:.3f}, returns={task_data['returns']}")
 
 
 
226
  (output_dir / "rollout_eval.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
227
  print(json.dumps(results, indent=2))
228
 
 
7
 
8
  import numpy as np
9
  import torch
10
+ from helpers.utils import create_obs_config
11
+ from rlbench.action_modes.action_mode import MoveArmThenGripper2Robots
12
+ from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning2Robots
13
+ from rlbench.action_modes.gripper_action_modes import Discrete2Robots
 
14
  from rlbench.backend.utils import task_file_to_task_class
15
+ from rlbench.environments_two_robots import Environment2Robots
16
 
17
  from models.action_decoder import ChunkDecoderConfig
18
  from models.backbones import FrozenVLBackboneConfig
 
22
  from models.policy import PolicyConfig
23
  from models.reveal_head import RevealHeadConfig
24
  from models.world_model import RevealWMConfig
25
+ from sim_rlbench.camera_spec import default_three_camera_spec
26
+ from sim_rlbench.dataset import (
27
+ bimanual_proprio_from_obs,
28
+ single_arm_absolute_action_from_delta,
29
+ stack_live_rgb_obs,
30
+ )
31
+ from train.trainer import TrainerConfig, build_policy, planner_enabled, policy_supports_planning
32
 
33
 
34
  def _policy_config_from_checkpoint(checkpoint: dict[str, Any]) -> PolicyConfig:
 
52
  return str(descriptions[0]) if descriptions else ""
53
 
54
 
55
+ def _step_bimanual_chunk(task: Any, obs: Any, delta_action: np.ndarray) -> tuple[Any, float, bool]:
56
+ total_reward = 0.0
57
+ done = False
58
+ next_obs = obs
59
+ for arm_name in ("right", "left"):
60
+ env_action = single_arm_absolute_action_from_delta(next_obs, delta_action, arm_name, ignore_collisions=True)
61
+ next_obs, reward, done = task.step(env_action, arm_name)
62
+ total_reward += float(reward)
63
+ if reward >= 1.0 or done:
64
+ break
65
+ return next_obs, total_reward, done
66
+
67
+
68
  def main() -> None:
69
  parser = argparse.ArgumentParser()
70
  parser.add_argument("--checkpoint", required=True)
 
88
  model.load_state_dict(checkpoint["state_dict"], strict=True)
89
  model.eval()
90
  plan_requested = bool(args.plan)
91
+ plan_applied = plan_requested and planner_enabled(trainer_config, during_eval=True)
92
  planning_note = None
93
+ if plan_requested and not policy_supports_planning(trainer_config.policy_type):
94
  plan_applied = False
95
  planning_note = "Planner requested for a backbone-only checkpoint; evaluating the backbone policy only."
96
+ elif plan_requested and trainer_config.planner_mode == "off":
97
+ planning_note = "Planner requested, but the checkpoint configuration sets planner_mode=off."
98
+ elif plan_requested and not args.allow_unsupervised_planning and trainer_config.planner_mode == "selfsup":
99
+ planning_note = "Planner is running in self-supervised mode without direct RLBench planner labels."
100
+
101
+ camera_spec = default_three_camera_spec(args.resolution)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  results: dict[str, Any] = {
104
  "checkpoint": str(Path(args.checkpoint).resolve()),
105
  "plan_requested": plan_requested,
106
  "plan_applied": plan_applied,
107
+ "planner_mode": trainer_config.planner_mode,
108
  "support_mode_conditioning": not args.disable_support_mode_conditioning,
109
  "episodes_per_task": args.episodes_per_task,
110
  "episode_length": args.episode_length,
111
  "resolution": args.resolution,
112
+ "cameras": list(camera_spec.cameras),
113
  "tasks": {},
114
  }
115
  if planning_note is not None:
116
  results["planning_note"] = planning_note
117
 
118
+ for task_name in args.tasks:
119
+ task_successes: list[float] = []
120
+ task_returns: list[float] = []
121
+ env: Environment2Robots | None = None
122
+ try:
123
+ task_class = task_file_to_task_class(task_name)
124
+ obs_config = create_obs_config(
125
+ list(camera_spec.upstream_cameras),
126
+ [args.resolution, args.resolution],
127
+ "PERACT_BC",
128
+ )
129
+ action_mode = MoveArmThenGripper2Robots(
130
+ EndEffectorPoseViaPlanning2Robots(absolute_mode=True, frame="world", collision_checking=False),
131
+ Discrete2Robots(),
132
+ )
133
+ env = Environment2Robots(
134
+ action_mode=action_mode,
135
+ obs_config=obs_config,
136
+ headless=args.headless,
137
+ robot_setup="panda",
138
+ task_name=task_class.__name__,
139
+ )
140
+ env.launch()
141
  task = env.get_task(task_class)
 
 
142
  for _ in range(args.episodes_per_task):
143
  descriptions, obs = task.reset()
144
  language_goal = _episode_language_goal(descriptions)
 
176
  dtype=proprio.dtype,
177
  )
178
  with torch.no_grad():
179
+ if policy_supports_planning(trainer_config.policy_type):
180
  outputs = model(
181
  images=images,
182
  proprio=proprio,
 
200
  step_action = chosen_chunk[0, 0].detach().float().cpu().numpy()
201
  if history_steps > 0:
202
  if len(history_images) >= history_steps:
203
+ keep = max(history_steps - 1, 0)
204
+ history_images = history_images[-keep:] if keep > 0 else []
205
+ history_proprio = history_proprio[-keep:] if keep > 0 else []
206
  history_images.append(images[0].detach().cpu().numpy())
207
  history_proprio.append(proprio[0].detach().cpu().numpy())
208
+ obs, reward, done = _step_bimanual_chunk(task, obs, step_action)
 
209
  total_reward += float(reward)
210
  if reward >= 1.0:
211
  success = 1.0
 
214
  task_successes.append(success)
215
  task_returns.append(total_reward)
216
  results["tasks"][task_name] = {
217
+ "task_class": task_class.__name__,
218
  "successes": task_successes,
219
  "returns": task_returns,
220
  "mean_success": float(np.mean(task_successes)) if task_successes else 0.0,
221
  "mean_return": float(np.mean(task_returns)) if task_returns else 0.0,
222
  }
223
+ except Exception as exc:
224
+ results["tasks"][task_name] = {"error": str(exc), "mean_success": 0.0, "mean_return": 0.0}
225
+ finally:
226
+ if env is not None:
227
+ env.shutdown()
228
 
229
  task_scores = [task_data["mean_success"] for task_data in results["tasks"].values()]
230
  results["mean_success"] = float(np.mean(task_scores)) if task_scores else 0.0
 
245
  "",
246
  ]
247
  for task_name, task_data in results["tasks"].items():
248
+ if "error" in task_data:
249
+ lines.append(f"- `{task_name}`: error={task_data['error']}")
250
+ else:
251
+ lines.append(f"- `{task_name}`: mean_success={task_data['mean_success']:.3f}, returns={task_data['returns']}")
252
  (output_dir / "rollout_eval.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
253
  print(json.dumps(results, indent=2))
254
 
code/reveal_vla_bimanual/models/__init__.py CHANGED
@@ -1,10 +1,11 @@
1
- from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig
2
  from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
3
  from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
4
- from models.planner import PlannerConfig, RevealPlanner
5
- from models.policy import BackboneOnlyPolicy, RevealBimanualPolicy
6
- from models.reveal_head import RevealHeadConfig, RevealStateHead
7
- from models.world_model import RevealWM, RevealWMConfig
 
8
 
9
  __all__ = [
10
  "ACTBimanualChunkDecoder",
@@ -12,8 +13,16 @@ __all__ = [
12
  "ChunkDecoderConfig",
13
  "FrozenVLBackbone",
14
  "FrozenVLBackboneConfig",
 
 
 
 
 
 
15
  "MultiViewFusion",
16
  "MultiViewFusionConfig",
 
 
17
  "PlannerConfig",
18
  "RevealBimanualPolicy",
19
  "RevealHeadConfig",
 
1
+ from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig, InteractionChunkDecoder
2
  from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
3
  from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
4
+ from models.observation_memory import InteractionObservationMemory, ObservationMemory, ObservationMemoryConfig
5
+ from models.planner import InteractionPlanner, PlannerConfig, RevealPlanner
6
+ from models.policy import BackboneOnlyPolicy, InteractionBimanualPolicy, RevealBimanualPolicy
7
+ from models.reveal_head import InteractionStateHead, RevealHeadConfig, RevealStateHead
8
+ from models.world_model import InteractionRolloutModel, RevealWM, RevealWMConfig
9
 
10
  __all__ = [
11
  "ACTBimanualChunkDecoder",
 
13
  "ChunkDecoderConfig",
14
  "FrozenVLBackbone",
15
  "FrozenVLBackboneConfig",
16
+ "InteractionBimanualPolicy",
17
+ "InteractionChunkDecoder",
18
+ "InteractionObservationMemory",
19
+ "InteractionPlanner",
20
+ "InteractionRolloutModel",
21
+ "InteractionStateHead",
22
  "MultiViewFusion",
23
  "MultiViewFusionConfig",
24
+ "ObservationMemory",
25
+ "ObservationMemoryConfig",
26
  "PlannerConfig",
27
  "RevealBimanualPolicy",
28
  "RevealHeadConfig",
code/reveal_vla_bimanual/models/action_decoder.py CHANGED
@@ -17,6 +17,8 @@ class ChunkDecoderConfig:
17
  action_dim: int = 14
18
  arm_action_dim: int = 7
19
  num_candidates: int = 8
 
 
20
 
21
 
22
  class ACTBimanualChunkDecoder(nn.Module):
@@ -157,3 +159,225 @@ class ACTBimanualChunkDecoder(nn.Module):
157
  candidates = action_mean.unsqueeze(1) + noise * std.unsqueeze(1)
158
  candidates[:, 0] = action_mean
159
  return candidates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  action_dim: int = 14
18
  arm_action_dim: int = 7
19
  num_candidates: int = 8
20
+ num_phases: int = 5
21
+ num_arm_roles: int = 4
22
 
23
 
24
  class ACTBimanualChunkDecoder(nn.Module):
 
159
  candidates = action_mean.unsqueeze(1) + noise * std.unsqueeze(1)
160
  candidates[:, 0] = action_mean
161
  return candidates
162
+
163
+
164
+ class InteractionChunkDecoder(nn.Module):
165
+ def __init__(self, config: ChunkDecoderConfig) -> None:
166
+ super().__init__()
167
+ self.config = config
168
+ decoder_layer = nn.TransformerDecoderLayer(
169
+ d_model=config.hidden_dim,
170
+ nhead=config.num_heads,
171
+ dim_feedforward=config.ff_dim,
172
+ dropout=config.dropout,
173
+ batch_first=True,
174
+ norm_first=True,
175
+ )
176
+ self.right_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
177
+ left_layer = nn.TransformerDecoderLayer(
178
+ d_model=config.hidden_dim,
179
+ nhead=config.num_heads,
180
+ dim_feedforward=config.ff_dim,
181
+ dropout=config.dropout,
182
+ batch_first=True,
183
+ norm_first=True,
184
+ )
185
+ self.left_decoder = nn.TransformerDecoder(left_layer, num_layers=config.num_layers)
186
+ self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
187
+ self.proposal_queries = nn.Embedding(config.num_candidates, config.hidden_dim)
188
+ self.arm_identity = nn.Embedding(2, config.hidden_dim)
189
+ self.phase_adapter = nn.Linear(config.num_phases, config.hidden_dim)
190
+ self.role_adapter = nn.Linear(config.num_arm_roles, config.hidden_dim)
191
+ self.context_proj = nn.Sequential(
192
+ nn.LayerNorm(config.hidden_dim),
193
+ nn.Linear(config.hidden_dim, config.hidden_dim),
194
+ nn.GELU(),
195
+ )
196
+ self.coordination = nn.Sequential(
197
+ nn.LayerNorm(config.hidden_dim * 3),
198
+ nn.Linear(config.hidden_dim * 3, config.hidden_dim),
199
+ nn.GELU(),
200
+ nn.Linear(config.hidden_dim, config.hidden_dim),
201
+ )
202
+ self.right_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
203
+ self.right_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
204
+ self.left_mean = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
205
+ self.left_log_std = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
206
+ self.proposal_score = nn.Sequential(
207
+ nn.LayerNorm(config.hidden_dim * 3),
208
+ nn.Linear(config.hidden_dim * 3, config.hidden_dim),
209
+ nn.GELU(),
210
+ nn.Linear(config.hidden_dim, 1),
211
+ )
212
+
213
+ def _conditioning(
214
+ self,
215
+ interaction_state: dict[str, Tensor] | None,
216
+ batch_size: int,
217
+ device: torch.device,
218
+ dtype: torch.dtype,
219
+ ) -> tuple[Tensor, Tensor, Tensor | None]:
220
+ if interaction_state is None:
221
+ zero_phase = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
222
+ zero_roles = torch.zeros(batch_size, 2, self.config.hidden_dim, device=device, dtype=dtype)
223
+ return zero_phase, zero_roles, None
224
+ phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
225
+ arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
226
+ phase_context = self.phase_adapter(phase_probs)
227
+ role_context = self.role_adapter(arm_role_probs)
228
+ return phase_context, role_context, interaction_state.get("interaction_tokens")
229
+
230
def _decode_from_queries(
    self,
    queries: Tensor,
    decoder_memory: Tensor,
    phase_context: Tensor,
    role_context: Tensor,
    interaction_context: Tensor,
) -> dict[str, Tensor]:
    """Decode one bimanual action chunk from a set of query tokens.

    Args:
        queries: Shared query tokens, presumably ``(batch, chunk, hidden)`` —
            TODO confirm against the enclosing class's ``query_embed``.
        decoder_memory: Key/value tokens for the per-arm decoders (scene +
            optional interaction/memory tokens, concatenated along dim 1).
        phase_context: Per-sample phase embedding, ``(batch, hidden)``.
        role_context: Per-arm role embeddings, ``(batch, 2, hidden)``
            (index 0 = right arm, index 1 = left arm, per the usage below).
        interaction_context: Pooled context vector, ``(batch, hidden)``.

    Returns:
        Dict with per-arm token streams, Gaussian action parameters
        (``action_mean`` / ``action_log_std``, arms concatenated on the last
        dim) and a scalar ``proposal_score`` per sample.
    """
    # Broadcast the phase embedding over the chunk dimension.
    phase_bias = phase_context.unsqueeze(1)
    # Each arm gets the shared queries plus its role embedding and a learned
    # arm-identity embedding (cast to the query dtype for mixed precision).
    right_queries = (
        queries
        + phase_bias
        + role_context[:, 0].unsqueeze(1)
        + self.arm_identity.weight[0].view(1, 1, -1).to(dtype=queries.dtype)
    )
    left_queries = (
        queries
        + phase_bias
        + role_context[:, 1].unsqueeze(1)
        + self.arm_identity.weight[1].view(1, 1, -1).to(dtype=queries.dtype)
    )
    right_tokens = self.right_decoder(right_queries, decoder_memory)
    # The left arm additionally attends over the right arm's decoded tokens,
    # giving an asymmetric right-then-left decoding order.
    left_tokens = self.left_decoder(left_queries, torch.cat([decoder_memory, right_tokens], dim=1))
    # Per-step coordination correction shared by both arms, bounded by tanh.
    context = interaction_context.unsqueeze(1).expand(-1, queries.shape[1], -1)
    coordination_input = torch.cat([right_tokens, left_tokens, context], dim=-1)
    coordination = torch.tanh(self.coordination(coordination_input))
    right_tokens = right_tokens + coordination
    left_tokens = left_tokens + coordination
    # Gaussian action head: mean and clamped log-std, right arm dims first.
    action_mean = torch.cat([self.right_mean(right_tokens), self.left_mean(left_tokens)], dim=-1)
    action_log_std = torch.cat(
        [self.right_log_std(right_tokens), self.left_log_std(left_tokens)],
        dim=-1,
    ).clamp(min=-5.0, max=2.0)
    # Chunk-level pooled features feed the candidate-proposal scorer.
    pooled_features = torch.cat(
        [right_tokens.mean(dim=1), left_tokens.mean(dim=1), coordination.mean(dim=1)],
        dim=-1,
    )
    return {
        "right_tokens": right_tokens,
        "left_tokens": left_tokens,
        "coordination_tokens": coordination,
        "decoded_tokens": torch.cat([right_tokens, left_tokens], dim=-1),
        "action_mean": action_mean,
        "action_log_std": action_log_std,
        "proposal_score": self.proposal_score(pooled_features).squeeze(-1),
    }
276
+
277
def forward(
    self,
    scene_tokens: Tensor,
    interaction_state: dict[str, Tensor] | None = None,
    memory_tokens: Tensor | None = None,
    reveal_tokens: Tensor | None = None,
    memory_token: Tensor | None = None,
) -> dict[str, Tensor]:
    """Decode a base action chunk plus a bank of candidate proposals.

    Args:
        scene_tokens: Fused scene tokens, ``(batch, tokens, hidden)``.
        interaction_state: Optional output of the interaction head; when
            present its phase/role logits condition the decoder.
        memory_tokens: Optional memory tokens appended to decoder memory.
        reveal_tokens: Legacy fallback tokens used only when
            ``interaction_state`` provides no interaction tokens.
        memory_token: Backward-compat alias for ``memory_tokens``.

    Returns:
        The base decode (see ``_decode_from_queries``) extended with
        ``proposal_candidates`` ``(batch, K, chunk, action_dim)`` and
        ``proposal_logits`` ``(batch, K)``; slot 0 is always the base mean.
    """
    # Accept the older singular keyword as an alias.
    if memory_tokens is None:
        memory_tokens = memory_token
    batch_size = scene_tokens.shape[0]
    dtype = scene_tokens.dtype
    phase_context, role_context, interaction_tokens = self._conditioning(
        interaction_state=interaction_state,
        batch_size=batch_size,
        device=scene_tokens.device,
        dtype=dtype,
    )

    # Decoder memory = scene tokens + (interaction | reveal) tokens + memory.
    # interaction tokens take precedence over the legacy reveal tokens.
    decoder_memory = scene_tokens
    if interaction_tokens is not None:
        decoder_memory = torch.cat([decoder_memory, interaction_tokens], dim=1)
    elif reveal_tokens is not None:
        decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
    if memory_tokens is not None:
        decoder_memory = torch.cat([decoder_memory, memory_tokens], dim=1)

    # Pooled context with the same precedence, falling back to the scene.
    if interaction_tokens is not None and interaction_tokens.numel() > 0:
        interaction_context = interaction_tokens.mean(dim=1)
    elif reveal_tokens is not None and reveal_tokens.numel() > 0:
        interaction_context = reveal_tokens.mean(dim=1)
    else:
        interaction_context = scene_tokens.mean(dim=1)
    interaction_context = self.context_proj(interaction_context)

    # Base decode from the shared learned queries.
    base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
    decoded = self._decode_from_queries(
        queries=base_queries,
        decoder_memory=decoder_memory,
        phase_context=phase_context,
        role_context=role_context,
        interaction_context=interaction_context,
    )

    # Candidate decode: each candidate perturbs the base queries with a
    # learned per-candidate bias; the batch and candidate axes are flattened
    # so all candidates share one decoder pass.
    num_candidates = self.config.num_candidates
    proposal_bias = self.proposal_queries.weight.view(1, num_candidates, 1, -1).expand(
        batch_size, -1, self.config.chunk_size, -1
    )
    candidate_queries = base_queries.unsqueeze(1) + proposal_bias
    flat_queries = candidate_queries.reshape(batch_size * num_candidates, self.config.chunk_size, self.config.hidden_dim)
    flat_memory = decoder_memory.unsqueeze(1).expand(-1, num_candidates, -1, -1).reshape(
        batch_size * num_candidates, decoder_memory.shape[1], decoder_memory.shape[2]
    )
    flat_phase = phase_context.unsqueeze(1).expand(-1, num_candidates, -1).reshape(
        batch_size * num_candidates, self.config.hidden_dim
    )
    flat_roles = role_context.unsqueeze(1).expand(-1, num_candidates, -1, -1).reshape(
        batch_size * num_candidates, 2, self.config.hidden_dim
    )
    flat_context = interaction_context.unsqueeze(1).expand(-1, num_candidates, -1).reshape(
        batch_size * num_candidates, self.config.hidden_dim
    )
    candidate_decoded = self._decode_from_queries(
        queries=flat_queries,
        decoder_memory=flat_memory,
        phase_context=flat_phase,
        role_context=flat_roles,
        interaction_context=flat_context,
    )

    proposal_deltas = candidate_decoded["action_mean"].view(
        batch_size,
        num_candidates,
        self.config.chunk_size,
        self.config.action_dim,
    )
    proposal_logits = candidate_decoded["proposal_score"].view(batch_size, num_candidates)
    # Candidates are bounded perturbations (tanh, scale 0.35) of the base
    # mean; slot 0 is pinned to the base decode so it is always available.
    proposal_candidates = decoded["action_mean"].unsqueeze(1) + 0.35 * torch.tanh(proposal_deltas)
    proposal_candidates[:, 0] = decoded["action_mean"]
    proposal_logits[:, 0] = decoded["proposal_score"]
    decoded["proposal_candidates"] = proposal_candidates
    decoded["proposal_logits"] = proposal_logits
    return decoded
360
+
361
def sample_candidates(
    self,
    action_mean: Tensor,
    action_log_std: Tensor,
    num_candidates: int | None = None,
    proposal_candidates: Tensor | None = None,
) -> Tensor:
    """Build a bank of candidate action chunks around the decoded mean.

    Precomputed ``proposal_candidates`` are passed through untouched.
    Otherwise ``num_candidates`` Gaussian perturbations of ``action_mean``
    are drawn (scaled by ``exp(action_log_std)``), with slot 0 pinned to
    the unperturbed mean. Returns ``(batch, K, chunk, action_dim)``.
    """
    if proposal_candidates is not None:
        return proposal_candidates
    if not num_candidates:
        num_candidates = self.config.num_candidates
    if num_candidates <= 1:
        # Degenerate bank: just the mean with a singleton candidate axis.
        return action_mean.unsqueeze(1)
    batch, chunk_len, act_dim = action_mean.size(0), action_mean.size(1), action_mean.size(2)
    perturbation = torch.randn(
        batch,
        num_candidates,
        chunk_len,
        act_dim,
        device=action_mean.device,
        dtype=action_mean.dtype,
    )
    scale = action_log_std.exp().unsqueeze(1)
    bank = action_mean.unsqueeze(1) + perturbation * scale
    # Candidate 0 is always the greedy (mean) chunk.
    bank[:, 0] = action_mean
    return bank
code/reveal_vla_bimanual/models/observation_memory.py CHANGED
@@ -12,6 +12,9 @@ class ObservationMemoryConfig:
12
  history_steps: int = 2
13
  num_layers: int = 1
14
  dropout: float = 0.1
 
 
 
15
 
16
 
17
  class ObservationMemory(nn.Module):
@@ -52,5 +55,86 @@ class ObservationMemory(nn.Module):
52
  "memory_sequence": memory_sequence,
53
  "memory_state": final_state,
54
  "memory_token": self.token_proj(final_state).unsqueeze(1),
 
55
  "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(final_state)).squeeze(-1),
56
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  history_steps: int = 2
13
  num_layers: int = 1
14
  dropout: float = 0.1
15
+ memory_bank_size: int = 4
16
+ num_heads: int = 4
17
+ max_history_steps: int = 8
18
 
19
 
20
  class ObservationMemory(nn.Module):
 
55
  "memory_sequence": memory_sequence,
56
  "memory_state": final_state,
57
  "memory_token": self.token_proj(final_state).unsqueeze(1),
58
+ "memory_tokens": self.token_proj(final_state).unsqueeze(1),
59
  "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(final_state)).squeeze(-1),
60
  }
61
+
62
+
63
class InteractionObservationMemory(nn.Module):
    """Transformer memory over pooled scene observations.

    Pools each frame's scene tokens to a single vector, encodes the
    (history + current) sequence with a Transformer, and compresses it into
    a fixed-size learned memory bank via cross-attention. Emits both a
    single pooled ``memory_token`` and the full ``memory_tokens`` bank,
    plus a softplus-positive scalar uncertainty.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.config = config
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim * 4,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        # Guard against a misconfigured num_layers <= 0.
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=max(1, config.num_layers))
        # Positions cover up to max_history_steps history frames + 1 current.
        self.position_embedding = nn.Parameter(
            torch.randn(1, config.max_history_steps + 1, config.hidden_dim) * 0.02
        )
        # Learned queries that read the encoded sequence into a fixed bank.
        self.bank_queries = nn.Parameter(torch.randn(config.memory_bank_size, config.hidden_dim) * 0.02)
        self.bank_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.bank_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def _truncate_history(self, history_scene_tokens: Tensor | None) -> Tensor | None:
        """Keep only the most recent ``config.history_steps`` history frames."""
        if history_scene_tokens is None or history_scene_tokens.numel() == 0:
            return history_scene_tokens
        if history_scene_tokens.shape[1] <= self.config.history_steps:
            return history_scene_tokens
        return history_scene_tokens[:, -self.config.history_steps :]

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode (history, current) observations into a memory bank.

        Args:
            scene_tokens: Current frame tokens, ``(batch, tokens, hidden)``.
            history_scene_tokens: Past frames; mean-pooled over dim 2, so
                presumably ``(batch, steps, tokens, hidden)`` — TODO confirm.

        Raises:
            ValueError: if the pooled sequence exceeds the position table.
        """
        # One pooled vector per frame; history frames are pooled over their
        # token axis (dim 2), the current frame over dim 1.
        pooled_current = scene_tokens.mean(dim=1, keepdim=True)
        history_scene_tokens = self._truncate_history(history_scene_tokens)
        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_pooled = history_scene_tokens.mean(dim=2)
            sequence = torch.cat([history_pooled, pooled_current], dim=1)
        else:
            sequence = pooled_current

        seq_len = sequence.shape[1]
        if seq_len > self.position_embedding.shape[1]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds configured maximum {self.position_embedding.shape[1]}"
            )
        encoded = self.sequence_encoder(sequence + self.position_embedding[:, :seq_len])
        batch_size = encoded.shape[0]
        # Compress the encoded sequence into a fixed-size bank via learned
        # queries, then refine with a residual MLP.
        queries = self.bank_queries.unsqueeze(0).expand(batch_size, -1, -1)
        bank_tokens, _ = self.bank_attention(queries, encoded, encoded)
        bank_tokens = bank_tokens + self.bank_mlp(bank_tokens)
        projected_bank = self.token_proj(bank_tokens)
        pooled_bank = projected_bank.mean(dim=1)
        return {
            "memory_sequence": encoded,
            # Last encoded position corresponds to the current frame.
            "memory_state": encoded[:, -1],
            "memory_token": pooled_bank.unsqueeze(1),
            "memory_tokens": projected_bank,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_bank)).squeeze(-1),
        }
code/reveal_vla_bimanual/models/planner.py CHANGED
@@ -11,6 +11,7 @@ class PlannerConfig:
11
  hidden_dim: int = 512
12
  num_candidates: int = 8
13
  action_dim: int = 14
 
14
  utility_margin: float = 0.1
15
  corridor_weight: float = 1.0
16
  persistence_weight: float = 0.5
@@ -19,6 +20,10 @@ class PlannerConfig:
19
  disturbance_weight: float = 0.75
20
  reocclusion_weight: float = 0.5
21
  visibility_weight: float = 0.25
 
 
 
 
22
 
23
 
24
  class RevealPlanner(nn.Module):
@@ -87,3 +92,113 @@ class RevealPlanner(nn.Module):
87
  "best_indices": best_idx,
88
  "best_chunk": candidate_chunks[batch_indices, best_idx],
89
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  hidden_dim: int = 512
12
  num_candidates: int = 8
13
  action_dim: int = 14
14
+ num_support_modes: int = 3
15
  utility_margin: float = 0.1
16
  corridor_weight: float = 1.0
17
  persistence_weight: float = 0.5
 
20
  disturbance_weight: float = 0.75
21
  reocclusion_weight: float = 0.5
22
  visibility_weight: float = 0.25
23
+ num_heads: int = 4
24
+ num_layers: int = 2
25
+ num_phases: int = 5
26
+ num_arm_roles: int = 4
27
 
28
 
29
  class RevealPlanner(nn.Module):
 
92
  "best_indices": best_idx,
93
  "best_chunk": candidate_chunks[batch_indices, best_idx],
94
  }
95
+
96
+
97
class InteractionPlanner(nn.Module):
    """Scores world-model rollouts of candidate action chunks.

    Each rollout step is summarized into a flat feature vector (action +
    predicted phase/role/support distributions + pooled field statistics),
    encoded as a sequence with a prepended CLS token, and scored by three
    linear heads (success, risk, utility). ``select_best`` returns the
    chunk with the highest combined utility.

    Fix vs. previous revision: the ``.view(batch, num_candidates)`` results
    were followed by ``.squeeze(-1)``, a no-op for K > 1 but one that
    collapsed the candidate axis to ``(batch,)`` when K == 1, making the
    downstream ``argmax(dim=-1)`` run over the batch axis. The redundant
    squeezes are removed so scores are always ``(batch, K)``.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        # Per-step feature width: action dims + phase probs + both arms'
        # role probs + support-mode probs + 7 pooled scalars (target,
        # 2x feasibility, persistence, risk, uncertainty, role gap).
        step_dim = (
            config.action_dim
            + config.num_phases
            + (2 * config.num_arm_roles)
            + config.num_support_modes
            + 7
        )
        self.step_proj = nn.Sequential(
            nn.LayerNorm(step_dim),
            nn.Linear(step_dim, config.hidden_dim),
            nn.GELU(),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim * 4,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_dim) * 0.02)
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)
        self.score_head = nn.Linear(config.hidden_dim, 1)

    def _mean_field(self, tensor: Tensor) -> Tensor:
        """Pool a spatial field over its last two (H, W) dims."""
        return tensor.mean(dim=(-1, -2))

    def summarize_trajectory(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> Tensor:
        """Concatenate per-step rollout statistics into ``(B, K, H, step_dim)``.

        The horizon is clipped to the shorter of the chunk and the rollout.
        """
        horizon = min(candidate_chunks.shape[2], rollout_state["phase_logits"].shape[2])
        candidate_steps = candidate_chunks[:, :, :horizon]
        phase_probs = rollout_state["phase_logits"][:, :, :horizon].softmax(dim=-1)
        support_probs = rollout_state["support_mode_logits"][:, :, :horizon].softmax(dim=-1)
        # Flatten the (arm, role) axes into a single feature axis.
        arm_role_probs = rollout_state["arm_role_logits"][:, :, :horizon].softmax(dim=-1).flatten(start_dim=-2)
        target_mean = self._mean_field(rollout_state["target_field"][:, :, :horizon].sigmoid())
        feasibility_mean = self._mean_field(rollout_state["actor_feasibility_field"][:, :, :horizon].sigmoid())
        persistence_mean = self._mean_field(rollout_state["persistence_field"][:, :, :horizon])
        risk_mean = self._mean_field(rollout_state["risk_field"][:, :, :horizon])
        uncertainty_mean = self._mean_field(rollout_state["uncertainty_field"][:, :, :horizon])
        # How differently the two arms are role-assigned at each step.
        role_gap = (
            rollout_state["arm_role_logits"][:, :, :horizon, 0].softmax(dim=-1)
            - rollout_state["arm_role_logits"][:, :, :horizon, 1].softmax(dim=-1)
        ).abs().mean(dim=-1, keepdim=True)
        return torch.cat(
            [
                candidate_steps,
                phase_probs,
                arm_role_probs,
                support_probs,
                target_mean,
                feasibility_mean,
                persistence_mean,
                risk_mean,
                uncertainty_mean,
                role_gap,
            ],
            dim=-1,
        )

    def score_rollouts(
        self,
        rollout_state: dict[str, Tensor],
        candidate_chunks: Tensor,
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score each candidate rollout; all score tensors are ``(B, K)``."""
        features = self.summarize_trajectory(candidate_chunks, rollout_state)
        batch_size, num_candidates, horizon, _ = features.shape
        flat_features = features.view(batch_size * num_candidates, horizon, -1)
        hidden_steps = self.step_proj(flat_features)
        cls = self.cls_token.expand(batch_size * num_candidates, -1, -1)
        encoded = self.sequence_encoder(torch.cat([cls, hidden_steps], dim=1))
        pooled = encoded[:, 0]
        # view already yields (B, K); do NOT squeeze, or K == 1 collapses.
        success_logits = self.success_head(pooled).view(batch_size, num_candidates)
        risk_values = torch.sigmoid(self.risk_head(pooled)).view(batch_size, num_candidates)
        utility_scores = self.score_head(pooled).view(batch_size, num_candidates)
        utility_scores = utility_scores + success_logits.sigmoid() - risk_values
        if proposal_logits is not None and proposal_logits.shape == utility_scores.shape:
            # NOTE(review): assumes PlannerConfig defines proposal_weight
            # (not visible in this hunk) — confirm against the config.
            utility_scores = utility_scores + self.config.proposal_weight * proposal_logits.sigmoid()
        return {
            "planner_features": features.mean(dim=2),
            "planner_hidden": pooled.view(batch_size, num_candidates, -1),
            "success_logits": success_logits,
            "risk_values": risk_values,
            "utility_scores": utility_scores,
        }

    def select_best(
        self,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score candidates and gather the argmax-utility chunk per sample."""
        outputs = self.score_rollouts(
            rollout_state=rollout_state,
            candidate_chunks=candidate_chunks,
            proposal_logits=proposal_logits,
        )
        best_idx = outputs["utility_scores"].argmax(dim=-1)
        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        return {
            **outputs,
            "best_indices": best_idx,
            "best_chunk": candidate_chunks[batch_indices, best_idx],
        }
code/reveal_vla_bimanual/models/policy.py CHANGED
@@ -6,13 +6,13 @@ from typing import Sequence
6
  import torch
7
  from torch import Tensor, nn
8
 
9
- from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig
10
  from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
11
  from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
12
- from models.observation_memory import ObservationMemory, ObservationMemoryConfig
13
- from models.planner import PlannerConfig, RevealPlanner
14
- from models.reveal_head import RevealHeadConfig, RevealStateHead
15
- from models.world_model import RevealWM, RevealWMConfig
16
 
17
 
18
  @dataclass
@@ -204,3 +204,138 @@ class RevealBimanualPolicy(BackboneOnlyPolicy):
204
  outputs["planner_scores"] = selected["utility_scores"]
205
  outputs["best_candidate_indices"] = selected["best_indices"]
206
  return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import torch
7
  from torch import Tensor, nn
8
 
9
+ from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig, InteractionChunkDecoder
10
  from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
11
  from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
12
+ from models.observation_memory import InteractionObservationMemory, ObservationMemory, ObservationMemoryConfig
13
+ from models.planner import InteractionPlanner, PlannerConfig, RevealPlanner
14
+ from models.reveal_head import InteractionStateHead, RevealHeadConfig, RevealStateHead
15
+ from models.world_model import InteractionRolloutModel, RevealWM, RevealWMConfig
16
 
17
 
18
  @dataclass
 
204
  outputs["planner_scores"] = selected["utility_scores"]
205
  outputs["best_candidate_indices"] = selected["best_indices"]
206
  return outputs
207
+
208
+
209
class InteractionBimanualPolicy(BackboneOnlyPolicy):
    """Full perceive → decode → rollout → plan bimanual policy.

    Pipeline: encode scene (+history) via the inherited backbone, build an
    observation memory, predict the interaction state, decode a base action
    chunk plus candidate proposals, roll candidates through the world model,
    and rank them with the planner. The ``use_*`` / ``*_override`` keyword
    flags implement the paper's ablations (no interaction head, no role
    tokens, short history, support-mode conditioning off).
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        self.memory = InteractionObservationMemory(config.memory)
        self.decoder = InteractionChunkDecoder(config.decoder)
        self.interaction_head = InteractionStateHead(config.reveal_head)
        self.world_model = InteractionRolloutModel(config.world_model)
        self.planner = InteractionPlanner(config.planner)

    def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
        """Repeat ``value`` per candidate: ``(B, ...)`` -> ``(B*K, ...)``."""
        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
            value.shape[0] * num_candidates,
            *value.shape[1:],
        )

    def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
        """Tile every tensor in an interaction-state dict per candidate."""
        return {key: self._tile_tensor(value, num_candidates) for key, value in state.items()}

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
        use_interaction_head: bool = True,
        use_role_tokens: bool = True,
        history_steps_override: int | None = None,
    ) -> dict[str, Tensor]:
        """Run the full policy; returns decoder outputs plus planning keys.

        When ``plan`` is False only the decode is returned. With planning
        on, adds ``candidate_chunks``, ``planned_chunk``, planner scores,
        ``best_candidate_indices`` and ``planned_rollout``.
        """
        # encode_scene / encode_history come from BackboneOnlyPolicy
        # (defined earlier in this file, outside this class).
        scene_tokens = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens)
        history_scene_tokens = self.encode_history(
            history_images,
            history_proprio,
            texts=texts,
            language_tokens=language_tokens,
        )
        # Ablation hook: artificially shorten the observed history.
        if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
        memory_output = self.memory(scene_tokens, history_scene_tokens=history_scene_tokens)

        interaction_state = None
        if use_interaction_head:
            interaction_state = self.interaction_head(
                scene_tokens,
                memory_tokens=memory_output["memory_tokens"],
            )
            # Expose memory tensors through the state for downstream users.
            interaction_state["memory_tokens"] = memory_output["memory_tokens"]
            interaction_state["memory_token"] = memory_output["memory_token"]

        # Ablation hook: zero role logits (copy first so the head's own
        # output dict is not mutated).
        if interaction_state is not None and not use_role_tokens:
            interaction_state = dict(interaction_state)
            interaction_state["arm_role_logits"] = torch.zeros_like(interaction_state["arm_role_logits"])

        decoded = self.decoder(
            scene_tokens,
            interaction_state=interaction_state,
            memory_tokens=memory_output["memory_tokens"],
        )
        outputs = {
            **decoded,
            "scene_tokens": scene_tokens,
            "history_scene_tokens": history_scene_tokens,
            "memory_output": memory_output,
            "memory_uncertainty": memory_output["memory_uncertainty"],
            "interaction_state": interaction_state,
            # Backward-compat alias for callers expecting the old key name.
            "reveal_state": interaction_state,
        }

        if plan:
            candidate_chunks = candidate_chunks_override
            proposal_logits = outputs.get("proposal_logits")
            if candidate_chunks is None:
                candidate_chunks = self.decoder.sample_candidates(
                    outputs["action_mean"],
                    outputs["action_log_std"],
                    num_candidates=self.config.decoder.num_candidates,
                    proposal_candidates=outputs.get("proposal_candidates"),
                )
            else:
                # Overridden candidates did not produce the logits, so the
                # planner must not mix them in.
                proposal_logits = None
            outputs["candidate_chunks"] = candidate_chunks

            # Without an interaction state the world model cannot roll out;
            # fall back to the mean chunk with zeroed planner outputs.
            if interaction_state is None:
                outputs["planned_chunk"] = outputs["action_mean"]
                outputs["planner_success_logits"] = torch.zeros(
                    candidate_chunks.shape[:2],
                    device=candidate_chunks.device,
                    dtype=candidate_chunks.dtype,
                )
                outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
                outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
                outputs["best_candidate_indices"] = torch.zeros(
                    candidate_chunks.shape[0],
                    dtype=torch.long,
                    device=candidate_chunks.device,
                )
                outputs["planned_rollout"] = {}
                return outputs

            # Flatten (batch, K) so the world model sees one big batch.
            batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape
            flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim)
            tiled_scene = self._tile_tensor(scene_tokens, num_candidates)
            planning_state = interaction_state
            # Ablation hook: plan without support-mode conditioning.
            if not support_mode_conditioning:
                planning_state = dict(interaction_state)
                planning_state["support_mode_logits"] = torch.zeros_like(interaction_state["support_mode_logits"])
            tiled_state = self._tile_state(planning_state, num_candidates)
            tiled_memory_tokens = self._tile_tensor(memory_output["memory_tokens"], num_candidates)
            rollout = self.world_model(
                scene_tokens=tiled_scene,
                interaction_state=tiled_state,
                action_chunk=flat_chunks,
                memory_tokens=tiled_memory_tokens,
            )
            # Restore the candidate axis for the planner.
            reshaped_rollout = {
                key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items()
            }
            selected = self.planner.select_best(
                candidate_chunks=candidate_chunks,
                rollout_state=reshaped_rollout,
                proposal_logits=proposal_logits,
            )
            outputs["planned_rollout"] = reshaped_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_scores"]
            outputs["best_candidate_indices"] = selected["best_indices"]
        return outputs
code/reveal_vla_bimanual/models/reveal_head.py CHANGED
@@ -17,6 +17,9 @@ class RevealHeadConfig:
17
  field_size: int = 16
18
  num_heads: int = 4
19
  predict_belief_map: bool = False
 
 
 
20
 
21
 
22
  class RevealStateHead(nn.Module):
@@ -116,3 +119,201 @@ class RevealStateHead(nn.Module):
116
  if self.config.predict_belief_map:
117
  output["belief_map"] = belief_map
118
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  field_size: int = 16
18
  num_heads: int = 4
19
  predict_belief_map: bool = False
20
+ num_phases: int = 5
21
+ num_arm_roles: int = 4
22
+ num_interaction_tokens: int = 8
23
 
24
 
25
  class RevealStateHead(nn.Module):
 
119
  if self.config.predict_belief_map:
120
  output["belief_map"] = belief_map
121
  return output
122
+
123
+
124
class InteractionFieldDecoder(nn.Module):
    """Decodes interaction tokens into spatial fields and state logits.

    Cross-attends a learned ``field_size x field_size`` grid of queries over
    interaction/scene/memory tokens, then reads out dense fields (target,
    feasibility, persistence, risk, uncertainty) via 1x1 convolutions and
    global states (phase, arm roles, support mode, reocclusion) from a
    pooled summary vector. Also emits legacy-named aliases so callers of
    the older reveal head keep working.

    NOTE(review): this class reads ``config.num_support_modes``,
    ``config.num_approach_templates``, ``config.rollout_horizon`` and
    ``config.belief_map_size``, which are not visible in the shown
    RevealHeadConfig hunk — confirm they are defined in the config.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query per cell of the field grid.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Summary = [interaction | field | scene | memory] pooled vectors.
        summary_dim = config.hidden_dim * 4
        self.summary_proj = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.phase_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_phases),
        )
        # Scores each arm token (plus summary) into a role distribution.
        self.arm_role_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_arm_roles),
        )
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.support_mode = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # 1x1 conv readouts over the spatial field grid.
        self.target_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.actor_feasibility_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.risk_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.compat_access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.compat_persistence = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )

    def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
        """Mean-pool tokens over dim 1, or zeros shaped like ``fallback``."""
        if source_tokens is None or source_tokens.numel() == 0:
            return fallback.new_zeros(fallback.shape)
        return source_tokens.mean(dim=1)

    def forward(
        self,
        interaction_tokens: Tensor,
        scene_tokens: Tensor | None = None,
        memory_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode fields and logits; see class docstring for output keys."""
        batch_size = interaction_tokens.shape[0]
        pooled_interaction = interaction_tokens.mean(dim=1)
        pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
        pooled_memory = self._pool_source(memory_tokens, pooled_interaction)

        # Field grid cross-attends over all available token sources.
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        source_tokens = interaction_tokens
        if scene_tokens is not None:
            source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
        if memory_tokens is not None:
            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)

        # Reshape the flat field tokens into a (C, H, W) grid for the convs.
        side = self.config.field_size
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_field = field_tokens.mean(dim=1)
        summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
        summary = self.summary_proj(summary_input)

        target_field = self.target_field(grid)
        actor_feasibility_field = self.actor_feasibility_field(grid)
        persistence_field = torch.sigmoid(self.persistence_field(grid))
        risk_field = torch.sigmoid(self.risk_field(grid))
        uncertainty_field = F.softplus(self.uncertainty_field(grid))

        # Corridor logits: column-wise max of the access field, resampled
        # to one logit per approach template.
        access_field = self.compat_access_field(grid)
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        # Persistence horizon: access-probability-weighted mean persistence,
        # scaled by the rollout horizon.
        compatibility_persistence = torch.sigmoid(self.compat_persistence(grid))
        access_prob = torch.sigmoid(access_field)
        weighted_persistence = (compatibility_persistence * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = risk_field.mean(dim=(-1, -2)).squeeze(1)
        # Belief map is the target field, resampled if sizes differ.
        belief_map = target_field
        if belief_map.shape[-1] != self.config.belief_map_size:
            belief_map = F.interpolate(
                belief_map,
                size=(self.config.belief_map_size, self.config.belief_map_size),
                mode="bilinear",
                align_corners=False,
            )

        # Per-arm role logits from the first two interaction tokens (or a
        # pooled fallback when fewer than two tokens are available).
        arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
        if interaction_tokens.shape[1] >= 2:
            arm_tokens = interaction_tokens[:, :2] + arm_identity
        else:
            arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
        arm_role_input = torch.cat(
            [arm_tokens, summary.unsqueeze(1).expand(-1, arm_tokens.shape[1], -1)],
            dim=-1,
        )
        arm_role_logits = self.arm_role_head(arm_role_input)
        reocclusion_logit = self.reocclusion_head(summary_input)

        output = {
            "phase_logits": self.phase_head(summary_input),
            "arm_role_logits": arm_role_logits,
            "target_field": target_field,
            "actor_feasibility_field": actor_feasibility_field,
            "persistence_field": persistence_field,
            "risk_field": risk_field,
            "uncertainty_field": uncertainty_field,
            "interaction_tokens": interaction_tokens,
            "field_tokens": field_tokens,
            "latent_summary": summary,
            "support_mode_logits": self.support_mode(summary_input),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "belief_map": belief_map,
            "reocclusion_logit": reocclusion_logit,
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            # Legacy aliases for consumers of the old reveal head.
            "access_field": access_field,
            "disturbance_field": risk_field,
            "uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
        }
        if not self.config.predict_belief_map:
            output.pop("belief_map")
        return output
278
+
279
+
280
class InteractionStateHead(nn.Module):
    """Predicts the interaction state from scene (and memory) tokens.

    A fixed set of learned interaction queries cross-attends over the
    concatenated scene + memory tokens, is refined by a residual MLP, and
    is handed to an ``InteractionFieldDecoder`` that produces the full
    state dict (phase/role/support logits, spatial fields, etc.).
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        self.interaction_queries = nn.Parameter(
            torch.randn(config.num_interaction_tokens, config.hidden_dim) * 0.02
        )
        self.interaction_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        self.interaction_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.decoder = InteractionFieldDecoder(config)

    def forward(
        self,
        scene_tokens: Tensor,
        memory_token: Tensor | None = None,
        memory_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Attend learned queries over scene/memory tokens and decode.

        ``memory_token`` is a backward-compat alias used only when
        ``memory_tokens`` is not given.
        """
        memory = memory_tokens if memory_tokens is not None else memory_token
        if memory is not None:
            attended_source = torch.cat([scene_tokens, memory], dim=1)
        else:
            attended_source = scene_tokens
        batch = attended_source.shape[0]
        query_bank = self.interaction_queries.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.interaction_attention(query_bank, attended_source, attended_source)
        # Residual MLP refinement of the attended interaction tokens.
        refined = attended + self.interaction_mlp(attended)
        return self.decoder(
            interaction_tokens=refined,
            scene_tokens=scene_tokens,
            memory_tokens=memory,
        )
code/reveal_vla_bimanual/models/world_model.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass
5
  import torch
6
  from torch import Tensor, nn
7
 
 
 
8
 
9
  @dataclass
10
  class RevealWMConfig:
@@ -13,6 +15,13 @@ class RevealWMConfig:
13
  num_support_modes: int = 3
14
  num_approach_templates: int = 32
15
  rollout_horizon: int = 5
 
 
 
 
 
 
 
16
 
17
 
18
  class RevealWM(nn.Module):
@@ -78,3 +87,68 @@ class RevealWM(nn.Module):
78
  "reocclusion_logit": self.reocclusion(rollout),
79
  "uncertainty": torch.nn.functional.softplus(self.uncertainty(rollout)).squeeze(-1),
80
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
  from torch import Tensor, nn
7
 
8
+ from models.reveal_head import InteractionFieldDecoder
9
+
10
 
11
  @dataclass
12
  class RevealWMConfig:
 
15
  num_support_modes: int = 3
16
  num_approach_templates: int = 32
17
  rollout_horizon: int = 5
18
+ field_size: int = 16
19
+ num_heads: int = 4
20
+ num_phases: int = 5
21
+ num_arm_roles: int = 4
22
+ num_interaction_tokens: int = 8
23
+ belief_map_size: int = 32
24
+ predict_belief_map: bool = True
25
 
26
 
27
  class RevealWM(nn.Module):
 
87
  "reocclusion_logit": self.reocclusion(rollout),
88
  "uncertainty": torch.nn.functional.softplus(self.uncertainty(rollout)).squeeze(-1),
89
  }
90
+
91
+
92
+ class InteractionRolloutModel(nn.Module):
93
+ def __init__(self, config: RevealWMConfig) -> None:
94
+ super().__init__()
95
+ self.config = config
96
+ self.action_encoder = nn.Sequential(
97
+ nn.LayerNorm(config.action_dim),
98
+ nn.Linear(config.action_dim, config.hidden_dim),
99
+ nn.GELU(),
100
+ )
101
+ encoder_layer = nn.TransformerEncoderLayer(
102
+ d_model=config.hidden_dim,
103
+ nhead=config.num_heads,
104
+ dim_feedforward=config.hidden_dim * 4,
105
+ batch_first=True,
106
+ norm_first=True,
107
+ )
108
+ self.transition = nn.TransformerEncoder(encoder_layer, num_layers=2)
109
+ self.token_update = nn.Sequential(
110
+ nn.LayerNorm(config.hidden_dim),
111
+ nn.Linear(config.hidden_dim, config.hidden_dim),
112
+ nn.GELU(),
113
+ nn.Linear(config.hidden_dim, config.hidden_dim),
114
+ )
115
+ self.decoder = InteractionFieldDecoder(config)
116
+
117
+ def forward(
118
+ self,
119
+ scene_tokens: Tensor,
120
+ interaction_state: dict[str, Tensor],
121
+ action_chunk: Tensor,
122
+ memory_tokens: Tensor | None = None,
123
+ ) -> dict[str, Tensor]:
124
+ if memory_tokens is None:
125
+ memory_tokens = interaction_state.get("memory_tokens")
126
+ if memory_tokens is None:
127
+ memory_tokens = interaction_state.get("memory_token")
128
+ current_tokens = interaction_state["interaction_tokens"]
129
+ outputs: dict[str, list[Tensor]] = {}
130
+
131
+ for step in range(action_chunk.shape[1]):
132
+ action_token = self.action_encoder(action_chunk[:, step]).unsqueeze(1)
133
+ transition_tokens = current_tokens
134
+ if memory_tokens is not None:
135
+ transition_tokens = torch.cat([transition_tokens, memory_tokens], dim=1)
136
+ transition_tokens = torch.cat([transition_tokens, action_token], dim=1)
137
+ transitioned = self.transition(transition_tokens)
138
+ current_tokens = current_tokens + self.token_update(transitioned[:, : current_tokens.shape[1]])
139
+ decoded = self.decoder(
140
+ interaction_tokens=current_tokens,
141
+ scene_tokens=scene_tokens,
142
+ memory_tokens=memory_tokens,
143
+ )
144
+ decoded["memory_token"] = (
145
+ memory_tokens.mean(dim=1, keepdim=True) if memory_tokens is not None else current_tokens.mean(dim=1, keepdim=True)
146
+ )
147
+ decoded["memory_tokens"] = memory_tokens if memory_tokens is not None else current_tokens[:, :1]
148
+ for key, value in decoded.items():
149
+ outputs.setdefault(key, []).append(value)
150
+
151
+ stacked: dict[str, Tensor] = {}
152
+ for key, values in outputs.items():
153
+ stacked[key] = torch.stack(values, dim=1)
154
+ return stacked
code/reveal_vla_bimanual/scripts/setup_env_a_rlbench.sh CHANGED
@@ -57,11 +57,28 @@ run_in_env python -m pip install -U pip setuptools wheel
57
  run_in_env python -m pip install --force-reinstall --no-deps numpy==1.26.4 Pillow==12.1.1
58
  run_in_env python -m pip install --no-deps moviepy==2.2.1 timeout-decorator==0.5.0 opencv-python==4.10.0.84 pyquaternion==0.9.9 click-prompt==0.5.1
59
  run_in_env python -m pip install --no-deps poetry-core
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  run_in_env python -m pip install --no-build-isolation -e "${PROJECT_DIR}"
61
- run_in_env python -m pip install --no-build-isolation --no-deps -e "${ROOT_DIR}/third_party/YARR"
62
- run_in_env python -m pip install --no-build-isolation --no-deps -e "${ROOT_DIR}/third_party/PyRep"
63
- run_in_env python -m pip install --no-build-isolation --no-deps -e "${ROOT_DIR}/third_party/RLBench"
64
- run_in_env python -m pip install --no-build-isolation --no-deps -e "${ROOT_DIR}/third_party/peract_bimanual"
65
 
66
  printf 'RLBench env ready at %s\n' "${ENV_PREFIX}"
67
  printf 'Activate with:\n'
 
57
  run_in_env python -m pip install --force-reinstall --no-deps numpy==1.26.4 Pillow==12.1.1
58
  run_in_env python -m pip install --no-deps moviepy==2.2.1 timeout-decorator==0.5.0 opencv-python==4.10.0.84 pyquaternion==0.9.9 click-prompt==0.5.1
59
  run_in_env python -m pip install --no-deps poetry-core
60
+ run_in_env python -m pip install gymnasium==1.0.0a2 open3d==0.19.0 segment-anything==1.0 transforms3d==0.4.1 openai==0.28.1
61
+ PERACT_ROOT="${ROOT_DIR}/third_party/peract_bimanual"
62
+ YARR_ROOT="${ROOT_DIR}/third_party/YARR"
63
+ PYREP_ROOT="${ROOT_DIR}/third_party/PyRep"
64
+ RLBENCH_ROOT="${ROOT_DIR}/third_party/RLBench"
65
+ if [[ -f "${PERACT_ROOT}/YARR/setup.py" ]]; then
66
+ YARR_ROOT="${PERACT_ROOT}/YARR"
67
+ fi
68
+ if [[ -f "${PERACT_ROOT}/PyRep/setup.py" ]]; then
69
+ PYREP_ROOT="${PERACT_ROOT}/PyRep"
70
+ fi
71
+ if [[ -f "${PERACT_ROOT}/RLBench/setup.py" ]]; then
72
+ RLBENCH_ROOT="${PERACT_ROOT}/RLBench"
73
+ fi
74
+ if [[ ! -f "${PERACT_ROOT}/pyproject.toml" && ! -f "${PERACT_ROOT}/setup.py" && -f "${PERACT_ROOT}/peract/setup.py" ]]; then
75
+ PERACT_ROOT="${PERACT_ROOT}/peract"
76
+ fi
77
  run_in_env python -m pip install --no-build-isolation -e "${PROJECT_DIR}"
78
+ run_in_env python -m pip install --no-build-isolation --no-deps -e "${YARR_ROOT}"
79
+ run_in_env python -m pip install --no-build-isolation --no-deps -e "${PYREP_ROOT}"
80
+ run_in_env python -m pip install --no-build-isolation --no-deps -e "${RLBENCH_ROOT}"
81
+ run_in_env python -m pip install --no-build-isolation --no-deps -e "${PERACT_ROOT}"
82
 
83
  printf 'RLBench env ready at %s\n' "${ENV_PREFIX}"
84
  printf 'Activate with:\n'
code/reveal_vla_bimanual/scripts/setup_rlbench_headless_x.sh CHANGED
@@ -9,9 +9,19 @@ export DEBIAN_FRONTEND=noninteractive
9
 
10
  apt-get update
11
  apt-get install -y \
 
 
 
 
 
 
 
 
12
  libxkbcommon0 \
13
  libxkbcommon-x11-0 \
14
  mesa-utils \
 
 
15
  x11-xserver-utils \
16
  xauth \
17
  xserver-xorg \
 
9
 
10
  apt-get update
11
  apt-get install -y \
12
+ libxcb-cursor0 \
13
+ libxcb-icccm4 \
14
+ libxcb-image0 \
15
+ libxcb-keysyms1 \
16
+ libxcb-randr0 \
17
+ libxcb-render-util0 \
18
+ libxcb-xinerama0 \
19
+ libxrender1 \
20
  libxkbcommon0 \
21
  libxkbcommon-x11-0 \
22
  mesa-utils \
23
+ xvfb \
24
+ x11-utils \
25
  x11-xserver-utils \
26
  xauth \
27
  xserver-xorg \
code/reveal_vla_bimanual/sim_rlbench/camera_spec.py CHANGED
@@ -2,6 +2,21 @@ from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  @dataclass(frozen=True)
7
  class RLBenchThreeCameraSpec:
@@ -16,8 +31,12 @@ class RLBenchThreeCameraSpec:
16
  def wrist_cameras(self) -> tuple[str, str]:
17
  return self.cameras[1], self.cameras[2]
18
 
 
 
 
 
19
  def hydra_overrides(self, prefix: str = "rlbench") -> list[str]:
20
- camera_list = ",".join(self.cameras)
21
  height, width = self.resolution
22
  return [
23
  f"{prefix}.cameras=[{camera_list}]",
 
2
 
3
  from dataclasses import dataclass
4
 
5
+ CANONICAL_TO_UPSTREAM_CAMERA = {
6
+ "front": "front",
7
+ "wrist_left": "wrist2",
8
+ "wrist_right": "wrist",
9
+ }
10
+ UPSTREAM_TO_CANONICAL_CAMERA = {value: key for key, value in CANONICAL_TO_UPSTREAM_CAMERA.items()}
11
+
12
+
13
+ def canonical_to_upstream_camera(camera_name: str) -> str:
14
+ return CANONICAL_TO_UPSTREAM_CAMERA.get(camera_name, camera_name)
15
+
16
+
17
+ def upstream_to_canonical_camera(camera_name: str) -> str:
18
+ return UPSTREAM_TO_CANONICAL_CAMERA.get(camera_name, camera_name)
19
+
20
 
21
  @dataclass(frozen=True)
22
  class RLBenchThreeCameraSpec:
 
31
  def wrist_cameras(self) -> tuple[str, str]:
32
  return self.cameras[1], self.cameras[2]
33
 
34
+ @property
35
+ def upstream_cameras(self) -> tuple[str, str, str]:
36
+ return tuple(CANONICAL_TO_UPSTREAM_CAMERA.get(camera, camera) for camera in self.cameras) # type: ignore[return-value]
37
+
38
  def hydra_overrides(self, prefix: str = "rlbench") -> list[str]:
39
+ camera_list = ",".join(self.upstream_cameras)
40
  height, width = self.resolution
41
  return [
42
  f"{prefix}.cameras=[{camera_list}]",
code/reveal_vla_bimanual/sim_rlbench/dataset.py CHANGED
@@ -10,10 +10,66 @@ import torch
10
  from PIL import Image
11
  from torch.utils.data import Dataset
12
 
 
 
13
 
14
  THREE_CAMERAS: tuple[str, str, str] = ("front", "wrist_left", "wrist_right")
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _normalize_quaternion_wxyz(quaternion: np.ndarray) -> np.ndarray:
18
  quaternion = np.asarray(quaternion, dtype=np.float32)
19
  return quaternion / max(float(np.linalg.norm(quaternion)), 1e-8)
@@ -90,11 +146,11 @@ def bimanual_proprio_from_obs(
90
  )
91
  base = np.concatenate(
92
  [
93
- np.asarray(obs.right.gripper_pose, dtype=np.float32),
94
- np.asarray(obs.left.gripper_pose, dtype=np.float32),
95
- np.asarray(obs.right.joint_positions, dtype=np.float32),
96
- np.asarray(obs.left.joint_positions, dtype=np.float32),
97
- np.array([float(obs.right.gripper_open), float(obs.left.gripper_open)], dtype=np.float32),
98
  time_feature,
99
  ],
100
  axis=0,
@@ -108,10 +164,8 @@ def bimanual_proprio_from_obs(
108
  def delta_action_from_transition(current_obs: Any, next_obs: Any) -> np.ndarray:
109
  action_parts: list[np.ndarray] = []
110
  for arm_name in ("right", "left"):
111
- current_arm = getattr(current_obs, arm_name)
112
- next_arm = getattr(next_obs, arm_name)
113
- current_pose = np.asarray(current_arm.gripper_pose, dtype=np.float32)
114
- next_pose = np.asarray(next_arm.gripper_pose, dtype=np.float32)
115
  position_delta = next_pose[:3] - current_pose[:3]
116
  current_quat = _xyzw_to_wxyz(current_pose[3:])
117
  next_quat = _xyzw_to_wxyz(next_pose[3:])
@@ -122,7 +176,7 @@ def delta_action_from_transition(current_obs: Any, next_obs: Any) -> np.ndarray:
122
  [
123
  position_delta.astype(np.float32),
124
  delta_rotvec.astype(np.float32),
125
- np.array([float(next_arm.gripper_open)], dtype=np.float32),
126
  ],
127
  axis=0,
128
  )
@@ -130,39 +184,55 @@ def delta_action_from_transition(current_obs: Any, next_obs: Any) -> np.ndarray:
130
  return np.concatenate(action_parts, axis=0).astype(np.float32)
131
 
132
 
133
- def absolute_action_from_delta(current_obs: Any, delta_action: Sequence[float], ignore_collisions: bool = True) -> np.ndarray:
 
 
 
 
 
134
  delta_action = np.asarray(delta_action, dtype=np.float32)
135
  if delta_action.shape != (14,):
136
  raise ValueError(f"Expected delta action shape (14,), received {delta_action.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- env_action: list[float] = []
139
- for arm_index, arm_name in enumerate(("right", "left")):
140
- arm = getattr(current_obs, arm_name)
141
- current_pose = np.asarray(arm.gripper_pose, dtype=np.float32)
142
- offset = arm_index * 7
143
- delta_position = delta_action[offset : offset + 3]
144
- delta_rotvec = delta_action[offset + 3 : offset + 6]
145
- gripper = float(delta_action[offset + 6] > 0.5)
146
- current_quat = _xyzw_to_wxyz(current_pose[3:])
147
- delta_quat = _rotvec_to_quat_wxyz(delta_rotvec)
148
- next_quat = _quat_multiply_wxyz(delta_quat, current_quat)
149
- next_pose = np.concatenate(
150
- [
151
- current_pose[:3] + delta_position,
152
- _wxyz_to_xyzw(next_quat),
153
- ],
154
- axis=0,
155
- )
156
- env_action.extend(next_pose.tolist())
157
- env_action.append(gripper)
158
- env_action.append(float(ignore_collisions))
159
- return np.asarray(env_action, dtype=np.float32)
160
 
161
 
162
  def stack_live_rgb_obs(obs: Any, cameras: Sequence[str] = THREE_CAMERAS, resolution: int = 224) -> torch.Tensor:
163
  images: list[np.ndarray] = []
164
  for camera_name in cameras:
165
- rgb = np.asarray(obs.perception_data[f"{camera_name}_rgb"], dtype=np.uint8)
 
 
 
166
  image = Image.fromarray(rgb)
167
  if image.size != (resolution, resolution):
168
  image = image.resize((resolution, resolution), Image.Resampling.BILINEAR)
@@ -259,6 +329,9 @@ class RLBenchOfflineChunkDataset(Dataset[dict[str, Any]]):
259
  frames: list[np.ndarray] = []
260
  for camera_name in self.cameras:
261
  image_path = episode_dir / f"{camera_name}_rgb" / f"rgb_{step_index:04d}.png"
 
 
 
262
  image = Image.open(image_path).convert("RGB")
263
  if image.size != (self.resolution, self.resolution):
264
  image = image.resize((self.resolution, self.resolution), Image.Resampling.BILINEAR)
 
10
  from PIL import Image
11
  from torch.utils.data import Dataset
12
 
13
+ from sim_rlbench.camera_spec import canonical_to_upstream_camera
14
+
15
 
16
  THREE_CAMERAS: tuple[str, str, str] = ("front", "wrist_left", "wrist_right")
17
 
18
 
19
+ def _camera_value(obs: Any, camera_name: str, suffix: str) -> Any:
20
+ upstream_name = canonical_to_upstream_camera(camera_name)
21
+ candidate_keys = [
22
+ f"{upstream_name}_{suffix}",
23
+ f"{camera_name}_{suffix}",
24
+ ]
25
+ if suffix == "point_cloud":
26
+ candidate_keys.extend(
27
+ [
28
+ f"{upstream_name}_pointcloud",
29
+ f"{camera_name}_pointcloud",
30
+ ]
31
+ )
32
+ for key in candidate_keys:
33
+ if hasattr(obs, key):
34
+ return getattr(obs, key)
35
+ perception_data = getattr(obs, "perception_data", None)
36
+ if isinstance(perception_data, dict):
37
+ for key in candidate_keys:
38
+ if key in perception_data:
39
+ return perception_data[key]
40
+ return None
41
+
42
+
43
+ def _arm_pose(obs: Any, arm_name: str) -> np.ndarray:
44
+ key = f"gripper_{arm_name}_pose"
45
+ if hasattr(obs, key):
46
+ return np.asarray(getattr(obs, key), dtype=np.float32)
47
+ arm = getattr(obs, arm_name, None)
48
+ if arm is not None and hasattr(arm, "gripper_pose"):
49
+ return np.asarray(arm.gripper_pose, dtype=np.float32)
50
+ raise AttributeError(f"Observation does not expose pose for arm '{arm_name}'")
51
+
52
+
53
+ def _arm_joint_positions(obs: Any, arm_name: str) -> np.ndarray:
54
+ key = f"joint_positions_{arm_name}"
55
+ if hasattr(obs, key):
56
+ return np.asarray(getattr(obs, key), dtype=np.float32)
57
+ arm = getattr(obs, arm_name, None)
58
+ if arm is not None and hasattr(arm, "joint_positions"):
59
+ return np.asarray(arm.joint_positions, dtype=np.float32)
60
+ raise AttributeError(f"Observation does not expose joint positions for arm '{arm_name}'")
61
+
62
+
63
+ def _arm_gripper_open(obs: Any, arm_name: str) -> float:
64
+ key = f"gripper_{arm_name}_open"
65
+ if hasattr(obs, key):
66
+ return float(getattr(obs, key))
67
+ arm = getattr(obs, arm_name, None)
68
+ if arm is not None and hasattr(arm, "gripper_open"):
69
+ return float(arm.gripper_open)
70
+ raise AttributeError(f"Observation does not expose gripper state for arm '{arm_name}'")
71
+
72
+
73
  def _normalize_quaternion_wxyz(quaternion: np.ndarray) -> np.ndarray:
74
  quaternion = np.asarray(quaternion, dtype=np.float32)
75
  return quaternion / max(float(np.linalg.norm(quaternion)), 1e-8)
 
146
  )
147
  base = np.concatenate(
148
  [
149
+ _arm_pose(obs, "right"),
150
+ _arm_pose(obs, "left"),
151
+ _arm_joint_positions(obs, "right"),
152
+ _arm_joint_positions(obs, "left"),
153
+ np.array([_arm_gripper_open(obs, "right"), _arm_gripper_open(obs, "left")], dtype=np.float32),
154
  time_feature,
155
  ],
156
  axis=0,
 
164
  def delta_action_from_transition(current_obs: Any, next_obs: Any) -> np.ndarray:
165
  action_parts: list[np.ndarray] = []
166
  for arm_name in ("right", "left"):
167
+ current_pose = _arm_pose(current_obs, arm_name)
168
+ next_pose = _arm_pose(next_obs, arm_name)
 
 
169
  position_delta = next_pose[:3] - current_pose[:3]
170
  current_quat = _xyzw_to_wxyz(current_pose[3:])
171
  next_quat = _xyzw_to_wxyz(next_pose[3:])
 
176
  [
177
  position_delta.astype(np.float32),
178
  delta_rotvec.astype(np.float32),
179
+ np.array([_arm_gripper_open(next_obs, arm_name)], dtype=np.float32),
180
  ],
181
  axis=0,
182
  )
 
184
  return np.concatenate(action_parts, axis=0).astype(np.float32)
185
 
186
 
187
+ def single_arm_absolute_action_from_delta(
188
+ current_obs: Any,
189
+ delta_action: Sequence[float],
190
+ arm_name: str,
191
+ ignore_collisions: bool = True,
192
+ ) -> np.ndarray:
193
  delta_action = np.asarray(delta_action, dtype=np.float32)
194
  if delta_action.shape != (14,):
195
  raise ValueError(f"Expected delta action shape (14,), received {delta_action.shape}")
196
+ arm_index = {"right": 0, "left": 1}[arm_name]
197
+ current_pose = _arm_pose(current_obs, arm_name)
198
+ offset = arm_index * 7
199
+ delta_position = delta_action[offset : offset + 3]
200
+ delta_rotvec = delta_action[offset + 3 : offset + 6]
201
+ gripper = float(delta_action[offset + 6] > 0.5)
202
+ current_quat = _xyzw_to_wxyz(current_pose[3:])
203
+ delta_quat = _rotvec_to_quat_wxyz(delta_rotvec)
204
+ next_quat = _quat_multiply_wxyz(delta_quat, current_quat)
205
+ next_pose = np.concatenate(
206
+ [
207
+ current_pose[:3] + delta_position,
208
+ _wxyz_to_xyzw(next_quat),
209
+ ],
210
+ axis=0,
211
+ )
212
+ return np.concatenate(
213
+ [
214
+ next_pose.astype(np.float32),
215
+ np.array([gripper, float(ignore_collisions)], dtype=np.float32),
216
+ ],
217
+ axis=0,
218
+ )
219
 
220
+
221
+ def absolute_action_from_delta(current_obs: Any, delta_action: Sequence[float], ignore_collisions: bool = True) -> np.ndarray:
222
+ arm_actions = [
223
+ single_arm_absolute_action_from_delta(current_obs, delta_action, arm_name, ignore_collisions=ignore_collisions)
224
+ for arm_name in ("right", "left")
225
+ ]
226
+ return np.concatenate(arm_actions, axis=0).astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
 
229
  def stack_live_rgb_obs(obs: Any, cameras: Sequence[str] = THREE_CAMERAS, resolution: int = 224) -> torch.Tensor:
230
  images: list[np.ndarray] = []
231
  for camera_name in cameras:
232
+ rgb_value = _camera_value(obs, camera_name, "rgb")
233
+ if rgb_value is None:
234
+ raise KeyError(f"Observation does not expose RGB for camera '{camera_name}'")
235
+ rgb = np.asarray(rgb_value, dtype=np.uint8)
236
  image = Image.fromarray(rgb)
237
  if image.size != (resolution, resolution):
238
  image = image.resize((resolution, resolution), Image.Resampling.BILINEAR)
 
329
  frames: list[np.ndarray] = []
330
  for camera_name in self.cameras:
331
  image_path = episode_dir / f"{camera_name}_rgb" / f"rgb_{step_index:04d}.png"
332
+ if not image_path.exists():
333
+ upstream_camera_name = canonical_to_upstream_camera(camera_name)
334
+ image_path = episode_dir / f"{upstream_camera_name}_rgb" / f"rgb_{step_index:04d}.png"
335
  image = Image.open(image_path).convert("RGB")
336
  if image.size != (self.resolution, self.resolution):
337
  image = image.resize((self.resolution, self.resolution), Image.Resampling.BILINEAR)
code/reveal_vla_bimanual/sim_rlbench/generate_smoke_dataset.py CHANGED
@@ -1,17 +1,17 @@
1
  from __future__ import annotations
2
 
3
  import argparse
 
4
  import pickle
5
  from pathlib import Path
6
 
7
  import numpy as np
8
  from PIL import Image
9
- from pyrep.const import RenderMode
10
- from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper
11
- from rlbench.action_modes.arm_action_modes import BimanualJointPosition
12
- from rlbench.action_modes.gripper_action_modes import BimanualDiscrete
13
  from rlbench.backend.const import (
14
- DEPTH_SCALE,
15
  EPISODE_FOLDER,
16
  EPISODES_FOLDER,
17
  LOW_DIM_PICKLE,
@@ -19,11 +19,24 @@ from rlbench.backend.const import (
19
  VARIATION_NUMBER,
20
  VARIATIONS_ALL_FOLDER,
21
  )
22
- from rlbench.backend.utils import float_array_to_rgb_image, task_file_to_task_class
23
- from rlbench.environment import Environment
24
- from rlbench.observation_config import CameraConfig, ObservationConfig
25
 
26
- from sim_rlbench.camera_spec import default_three_camera_spec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  def _save_demo(demo, episode_path: Path, cameras: list[str]) -> None:
@@ -33,19 +46,14 @@ def _save_demo(demo, episode_path: Path, cameras: list[str]) -> None:
33
  for dtype in data_types:
34
  output_dir = episode_path / f"{camera_name}_{dtype}"
35
  output_dir.mkdir(parents=True, exist_ok=True)
36
- payload = obs.perception_data.get(f"{camera_name}_{dtype}")
37
  if payload is None:
38
  continue
39
  if dtype == "rgb":
40
- image = Image.fromarray(payload)
41
- elif dtype == "depth":
42
- image = float_array_to_rgb_image(payload, scale_factor=DEPTH_SCALE)
43
- elif dtype == "mask":
44
- image = Image.fromarray((payload * 255).astype(np.uint8))
45
  else:
46
- raise ValueError(dtype)
47
  image.save(output_dir / f"{dtype}_{obs_idx:04d}.png")
48
- obs.perception_data.clear()
49
 
50
  with (episode_path / LOW_DIM_PICKLE).open("wb") as handle:
51
  pickle.dump(demo, handle)
@@ -53,43 +61,31 @@ def _save_demo(demo, episode_path: Path, cameras: list[str]) -> None:
53
 
54
  def main() -> None:
55
  parser = argparse.ArgumentParser()
56
- parser.add_argument("--task", default="bimanual_lift_ball")
57
  parser.add_argument("--episodes", type=int, default=1)
58
  parser.add_argument("--resolution", type=int, default=224)
59
- parser.add_argument("--output-root", default="/workspace/data/rlbench2_smoke")
 
 
60
  args = parser.parse_args()
61
 
62
  spec = default_three_camera_spec(args.resolution)
63
- camera_config = CameraConfig(
64
- rgb=True,
65
- depth=True,
66
- point_cloud=False,
67
- mask=True,
68
- image_size=list(spec.resolution),
69
- render_mode=RenderMode.OPENGL,
70
- masks_as_one_channel=False,
71
- depth_in_meters=False,
72
- )
73
- obs_config = ObservationConfig(
74
- camera_configs={camera_name: camera_config for camera_name in spec.cameras},
75
- joint_forces=False,
76
- joint_positions=True,
77
- joint_velocities=True,
78
- task_low_dim_state=False,
79
- gripper_touch_forces=False,
80
- gripper_pose=True,
81
- gripper_open=True,
82
- gripper_matrix=True,
83
- gripper_joint_positions=True,
84
- robot_name="bimanual",
85
  )
86
 
87
- task_class = task_file_to_task_class(args.task, bimanual=True)
88
- env = Environment(
89
- action_mode=BimanualMoveArmThenGripper(BimanualJointPosition(), BimanualDiscrete()),
 
 
 
90
  obs_config=obs_config,
91
- robot_setup="dual_panda",
92
  headless=True,
 
 
93
  )
94
  output_root = Path(args.output_root)
95
  episodes_root = output_root / args.task / VARIATIONS_ALL_FOLDER / EPISODES_FOLDER
@@ -105,8 +101,17 @@ def main() -> None:
105
  task_env = env.get_task(task_class)
106
  variation = int(rng.integers(variation_count))
107
  task_env.set_variation(variation)
108
- descriptions, _ = task_env.reset()
109
- (demo,) = task_env.get_demos(amount=1, live_demos=True)
 
 
 
 
 
 
 
 
 
110
  episode_path = episodes_root / (EPISODE_FOLDER % episode_idx)
111
  episode_path.mkdir(parents=True, exist_ok=True)
112
  _save_demo(demo, episode_path, list(spec.cameras))
@@ -115,7 +120,7 @@ def main() -> None:
115
  with (episode_path / VARIATION_DESCRIPTIONS).open("wb") as handle:
116
  pickle.dump(descriptions, handle)
117
  print(
118
- f"[done] wrote {args.task} episode {episode_idx} variation {variation} to {episode_path}",
119
  flush=True,
120
  )
121
  finally:
 
1
  from __future__ import annotations
2
 
3
  import argparse
4
+ import copy
5
  import pickle
6
  from pathlib import Path
7
 
8
  import numpy as np
9
  from PIL import Image
10
+ from helpers.utils import create_obs_config
11
+ from rlbench.action_modes.action_mode import MoveArmThenGripper2Robots
12
+ from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning2Robots
13
+ from rlbench.action_modes.gripper_action_modes import Discrete2Robots
14
  from rlbench.backend.const import (
 
15
  EPISODE_FOLDER,
16
  EPISODES_FOLDER,
17
  LOW_DIM_PICKLE,
 
19
  VARIATION_NUMBER,
20
  VARIATIONS_ALL_FOLDER,
21
  )
22
+ from rlbench.backend.utils import task_file_to_task_class
23
+ from rlbench.environments_two_robots import Environment2Robots
 
24
 
25
+ from sim_rlbench.camera_spec import canonical_to_upstream_camera, default_three_camera_spec
26
+
27
+
28
+ def _camera_payload(obs: object, camera_name: str, suffix: str):
29
+ upstream_name = canonical_to_upstream_camera(camera_name)
30
+ for key in (f"{upstream_name}_{suffix}", f"{camera_name}_{suffix}"):
31
+ if hasattr(obs, key):
32
+ return getattr(obs, key)
33
+ return None
34
+
35
+
36
+ def _scripted_demo(task_env: object, steps_per_episode: int) -> tuple[list[str], list[object]]:
37
+ descriptions, obs = task_env.reset()
38
+ demo = [copy.deepcopy(obs) for _ in range(max(steps_per_episode, 2))]
39
+ return descriptions, demo
40
 
41
 
42
  def _save_demo(demo, episode_path: Path, cameras: list[str]) -> None:
 
46
  for dtype in data_types:
47
  output_dir = episode_path / f"{camera_name}_{dtype}"
48
  output_dir.mkdir(parents=True, exist_ok=True)
49
+ payload = _camera_payload(obs, camera_name, dtype)
50
  if payload is None:
51
  continue
52
  if dtype == "rgb":
53
+ image = Image.fromarray(np.asarray(payload, dtype=np.uint8))
 
 
 
 
54
  else:
55
+ image = Image.fromarray((np.asarray(payload) * 255).astype(np.uint8))
56
  image.save(output_dir / f"{dtype}_{obs_idx:04d}.png")
 
57
 
58
  with (episode_path / LOW_DIM_PICKLE).open("wb") as handle:
59
  pickle.dump(demo, handle)
 
61
 
62
  def main() -> None:
63
  parser = argparse.ArgumentParser()
64
+ parser.add_argument("--task", default="open_drawer")
65
  parser.add_argument("--episodes", type=int, default=1)
66
  parser.add_argument("--resolution", type=int, default=224)
67
+ parser.add_argument("--output-root", default="/workspace/data/rlbench_smoke")
68
+ parser.add_argument("--steps-per-episode", type=int, default=6)
69
+ parser.add_argument("--try-live-demos", action="store_true")
70
  args = parser.parse_args()
71
 
72
  spec = default_three_camera_spec(args.resolution)
73
+ obs_config = create_obs_config(
74
+ list(spec.upstream_cameras),
75
+ [args.resolution, args.resolution],
76
+ "PERACT_BC",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  )
78
 
79
+ task_class = task_file_to_task_class(args.task)
80
+ env = Environment2Robots(
81
+ action_mode=MoveArmThenGripper2Robots(
82
+ EndEffectorPoseViaPlanning2Robots(absolute_mode=True, frame="world", collision_checking=False),
83
+ Discrete2Robots(),
84
+ ),
85
  obs_config=obs_config,
 
86
  headless=True,
87
+ robot_setup="panda",
88
+ task_name=task_class.__name__,
89
  )
90
  output_root = Path(args.output_root)
91
  episodes_root = output_root / args.task / VARIATIONS_ALL_FOLDER / EPISODES_FOLDER
 
101
  task_env = env.get_task(task_class)
102
  variation = int(rng.integers(variation_count))
103
  task_env.set_variation(variation)
104
+ if args.try_live_demos:
105
+ collection_mode = "live_demo"
106
+ try:
107
+ (demo,) = task_env.get_demos(amount=1, live_demos=True, max_attempts=1)
108
+ descriptions = task_env.get_task_descriptions()
109
+ except Exception as exc:
110
+ collection_mode = f"scripted_fallback: {exc}"
111
+ descriptions, demo = _scripted_demo(task_env, steps_per_episode=args.steps_per_episode)
112
+ else:
113
+ collection_mode = "scripted"
114
+ descriptions, demo = _scripted_demo(task_env, steps_per_episode=args.steps_per_episode)
115
  episode_path = episodes_root / (EPISODE_FOLDER % episode_idx)
116
  episode_path.mkdir(parents=True, exist_ok=True)
117
  _save_demo(demo, episode_path, list(spec.cameras))
 
120
  with (episode_path / VARIATION_DESCRIPTIONS).open("wb") as handle:
121
  pickle.dump(descriptions, handle)
122
  print(
123
+ f"[done] wrote {args.task} episode {episode_idx} variation {variation} via {collection_mode} to {episode_path}",
124
  flush=True,
125
  )
126
  finally:
code/reveal_vla_bimanual/sim_rlbench/obs_adapter.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
5
 
6
  import numpy as np
7
 
8
- from sim_rlbench.camera_spec import RLBenchThreeCameraSpec
9
 
10
 
11
  @dataclass
@@ -31,18 +31,66 @@ class CanonicalBimanualObservation:
31
 
32
 
33
  def _camera_rgb(obs: Any, camera_name: str) -> np.ndarray:
34
- value = obs.perception_data[f"{camera_name}_rgb"]
35
- return np.asarray(value, dtype=np.uint8)
 
 
 
 
 
 
 
 
36
 
37
 
38
  def _camera_point_cloud(obs: Any, camera_name: str) -> np.ndarray:
39
- value = obs.perception_data[f"{camera_name}_point_cloud"]
40
- return np.asarray(value, dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  def _bimanual_proprio(obs: Any, timestep: int | None = None, episode_length: int | None = None) -> np.ndarray:
44
- right = np.asarray(obs.get_low_dim_data(obs.right), dtype=np.float32)
45
- left = np.asarray(obs.get_low_dim_data(obs.left), dtype=np.float32)
 
 
 
 
 
 
 
 
46
  proprio = np.concatenate([right, left], axis=0)
47
  if timestep is not None and episode_length and episode_length > 1:
48
  time_feature = np.array(
@@ -64,11 +112,11 @@ def extract_canonical_bimanual_obs(
64
  camera_spec = camera_spec or RLBenchThreeCameraSpec()
65
  rgb = {camera: _camera_rgb(obs, camera) for camera in camera_spec.cameras}
66
  intrinsics = {
67
- camera: np.asarray(obs.misc[f"{camera}_camera_intrinsics"], dtype=np.float32)
68
  for camera in camera_spec.cameras
69
  }
70
  extrinsics = {
71
- camera: np.asarray(obs.misc[f"{camera}_camera_extrinsics"], dtype=np.float32)
72
  for camera in camera_spec.cameras
73
  }
74
  point_cloud = None
 
5
 
6
  import numpy as np
7
 
8
+ from sim_rlbench.camera_spec import RLBenchThreeCameraSpec, canonical_to_upstream_camera
9
 
10
 
11
  @dataclass
 
31
 
32
 
33
  def _camera_rgb(obs: Any, camera_name: str) -> np.ndarray:
34
+ upstream_name = canonical_to_upstream_camera(camera_name)
35
+ for key in (f"{upstream_name}_rgb", f"{camera_name}_rgb"):
36
+ if hasattr(obs, key):
37
+ return np.asarray(getattr(obs, key), dtype=np.uint8)
38
+ perception_data = getattr(obs, "perception_data", None)
39
+ if isinstance(perception_data, dict):
40
+ for key in (f"{upstream_name}_rgb", f"{camera_name}_rgb"):
41
+ if key in perception_data:
42
+ return np.asarray(perception_data[key], dtype=np.uint8)
43
+ raise KeyError(f"Observation does not expose RGB for camera '{camera_name}'")
44
 
45
 
46
  def _camera_point_cloud(obs: Any, camera_name: str) -> np.ndarray:
47
+ upstream_name = canonical_to_upstream_camera(camera_name)
48
+ for key in (
49
+ f"{upstream_name}_point_cloud",
50
+ f"{upstream_name}_pointcloud",
51
+ f"{camera_name}_point_cloud",
52
+ f"{camera_name}_pointcloud",
53
+ ):
54
+ if hasattr(obs, key):
55
+ return np.asarray(getattr(obs, key), dtype=np.float32)
56
+ perception_data = getattr(obs, "perception_data", None)
57
+ if isinstance(perception_data, dict):
58
+ for key in (
59
+ f"{upstream_name}_point_cloud",
60
+ f"{upstream_name}_pointcloud",
61
+ f"{camera_name}_point_cloud",
62
+ f"{camera_name}_pointcloud",
63
+ ):
64
+ if key in perception_data:
65
+ return np.asarray(perception_data[key], dtype=np.float32)
66
+ raise KeyError(f"Observation does not expose point clouds for camera '{camera_name}'")
67
+
68
+
69
+ def _camera_misc(obs: Any, camera_name: str, field_name: str) -> np.ndarray:
70
+ upstream_name = canonical_to_upstream_camera(camera_name)
71
+ misc = getattr(obs, "misc", {})
72
+ for key in (
73
+ f"{upstream_name}_camera_{field_name}",
74
+ f"{camera_name}_camera_{field_name}",
75
+ f"{upstream_name}_{field_name}",
76
+ f"{camera_name}_{field_name}",
77
+ ):
78
+ if key in misc:
79
+ return np.asarray(misc[key], dtype=np.float32)
80
+ raise KeyError(f"Observation misc does not expose {field_name} for camera '{camera_name}'")
81
 
82
 
83
  def _bimanual_proprio(obs: Any, timestep: int | None = None, episode_length: int | None = None) -> np.ndarray:
84
+ if hasattr(obs, "get_low_dim_data"):
85
+ try:
86
+ right = np.asarray(obs.get_low_dim_data("right"), dtype=np.float32)
87
+ left = np.asarray(obs.get_low_dim_data("left"), dtype=np.float32)
88
+ except Exception:
89
+ right = np.asarray(obs.get_low_dim_data(getattr(obs, "right")), dtype=np.float32)
90
+ left = np.asarray(obs.get_low_dim_data(getattr(obs, "left")), dtype=np.float32)
91
+ else:
92
+ right = np.asarray(getattr(obs.right, "get_low_dim_data")(), dtype=np.float32)
93
+ left = np.asarray(getattr(obs.left, "get_low_dim_data")(), dtype=np.float32)
94
  proprio = np.concatenate([right, left], axis=0)
95
  if timestep is not None and episode_length and episode_length > 1:
96
  time_feature = np.array(
 
112
  camera_spec = camera_spec or RLBenchThreeCameraSpec()
113
  rgb = {camera: _camera_rgb(obs, camera) for camera in camera_spec.cameras}
114
  intrinsics = {
115
+ camera: _camera_misc(obs, camera, "intrinsics")
116
  for camera in camera_spec.cameras
117
  }
118
  extrinsics = {
119
+ camera: _camera_misc(obs, camera, "extrinsics")
120
  for camera in camera_spec.cameras
121
  }
122
  point_cloud = None
code/reveal_vla_bimanual/sim_rlbench/peract2_runner.py CHANGED
@@ -29,22 +29,38 @@ def _default_nvidia_shim_root() -> Path | None:
29
  return candidate if candidate.exists() else None
30
 
31
 
 
 
 
 
 
 
 
 
 
32
  @dataclass
33
  class BenchmarkRunSpec:
34
  upstream_root: Path = Path("/workspace/third_party/peract_bimanual")
35
  demo_path: Path = Path("/workspace/data/rlbench2")
36
  replay_path: Path = Path("/workspace/replays/rlbench2")
37
  logdir: Path = Path("/workspace/logs/rlbench2")
38
- method: str = "BIMANUAL_PERACT"
39
  tasks: tuple[str, ...] = field(default_factory=lambda: PERACT2_BIMANUAL_TASKS)
40
  demos: int = 100
41
  training_iterations: int = 40000
42
  seed: int = 0
43
  gpu: int = 0
44
  display: str = ":99"
 
 
45
  coppeliasim_root: Path = Path("/workspace/assets/coppeliasim_v4_1_0")
46
  camera_spec: RLBenchThreeCameraSpec = field(default_factory=default_three_camera_spec)
47
 
 
 
 
 
 
48
  def common_overrides(self) -> list[str]:
49
  task_name = "multi_3cam" if len(self.tasks) > 1 else self.tasks[0]
50
  overrides = [
@@ -66,7 +82,7 @@ class BenchmarkRunSpec:
66
 
67
  def train_command(self, python_executable: str | None = None) -> list[str]:
68
  python_executable = python_executable or sys.executable
69
- return [python_executable, "train.py", *self.common_overrides()]
70
 
71
  def eval_command(
72
  self,
@@ -76,7 +92,7 @@ class BenchmarkRunSpec:
76
  python_executable: str | None = None,
77
  ) -> list[str]:
78
  python_executable = python_executable or sys.executable
79
- return [
80
  python_executable,
81
  "eval.py",
82
  f"method={self.method}",
@@ -85,7 +101,7 @@ class BenchmarkRunSpec:
85
  f"eval_episodes={episodes}",
86
  f"cinematic_recorder.enabled={str(save_videos)}",
87
  *self.camera_spec.hydra_overrides(),
88
- ]
89
 
90
  def env(self) -> dict[str, str]:
91
  env = os.environ.copy()
@@ -117,7 +133,7 @@ class BenchmarkRunSpec:
117
  def run_train(self) -> subprocess.CompletedProcess[bytes]:
118
  return subprocess.run(
119
  self.train_command(),
120
- cwd=self.upstream_root,
121
  env=self.env(),
122
  check=True,
123
  )
 
29
  return candidate if candidate.exists() else None
30
 
31
 
32
+ def resolve_upstream_root(upstream_root: Path) -> Path:
33
+ if (upstream_root / "train.py").exists():
34
+ return upstream_root
35
+ nested_root = upstream_root / "peract"
36
+ if (nested_root / "train.py").exists():
37
+ return nested_root
38
+ return upstream_root
39
+
40
+
41
  @dataclass
42
  class BenchmarkRunSpec:
43
  upstream_root: Path = Path("/workspace/third_party/peract_bimanual")
44
  demo_path: Path = Path("/workspace/data/rlbench2")
45
  replay_path: Path = Path("/workspace/replays/rlbench2")
46
  logdir: Path = Path("/workspace/logs/rlbench2")
47
+ method: str = "PERACT_BC"
48
  tasks: tuple[str, ...] = field(default_factory=lambda: PERACT2_BIMANUAL_TASKS)
49
  demos: int = 100
50
  training_iterations: int = 40000
51
  seed: int = 0
52
  gpu: int = 0
53
  display: str = ":99"
54
+ use_xvfb: bool = True
55
+ xvfb_screen: str = "1280x1024x24"
56
  coppeliasim_root: Path = Path("/workspace/assets/coppeliasim_v4_1_0")
57
  camera_spec: RLBenchThreeCameraSpec = field(default_factory=default_three_camera_spec)
58
 
59
+ def _wrap_display(self, command: list[str]) -> list[str]:
60
+ if not self.use_xvfb:
61
+ return command
62
+ return ["xvfb-run", "-a", "-s", f"-screen 0 {self.xvfb_screen}", *command]
63
+
64
  def common_overrides(self) -> list[str]:
65
  task_name = "multi_3cam" if len(self.tasks) > 1 else self.tasks[0]
66
  overrides = [
 
82
 
83
  def train_command(self, python_executable: str | None = None) -> list[str]:
84
  python_executable = python_executable or sys.executable
85
+ return self._wrap_display([python_executable, "train.py", *self.common_overrides()])
86
 
87
  def eval_command(
88
  self,
 
92
  python_executable: str | None = None,
93
  ) -> list[str]:
94
  python_executable = python_executable or sys.executable
95
+ return self._wrap_display([
96
  python_executable,
97
  "eval.py",
98
  f"method={self.method}",
 
101
  f"eval_episodes={episodes}",
102
  f"cinematic_recorder.enabled={str(save_videos)}",
103
  *self.camera_spec.hydra_overrides(),
104
+ ])
105
 
106
  def env(self) -> dict[str, str]:
107
  env = os.environ.copy()
 
133
  def run_train(self) -> subprocess.CompletedProcess[bytes]:
134
  return subprocess.run(
135
  self.train_command(),
136
+ cwd=resolve_upstream_root(self.upstream_root),
137
  env=self.env(),
138
  check=True,
139
  )
code/reveal_vla_bimanual/sim_rlbench/smoke_test.py CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
2
 
3
  import argparse
4
  import json
 
5
  from pathlib import Path
6
 
7
  from sim_rlbench.camera_spec import default_three_camera_spec
8
- from sim_rlbench.peract2_runner import BenchmarkRunSpec
9
 
10
 
11
  def main() -> None:
@@ -20,6 +21,7 @@ def main() -> None:
20
  "camera_names": list(spec.cameras),
21
  "resolution": list(spec.resolution),
22
  "global_camera": spec.global_camera,
 
23
  }
24
 
25
  import_status = {}
@@ -39,7 +41,7 @@ def main() -> None:
39
  demo_path=Path(args.demo_path),
40
  camera_spec=spec,
41
  )
42
- print(" ".join(run_spec.train_command()))
43
 
44
 
45
  if __name__ == "__main__":
 
2
 
3
  import argparse
4
  import json
5
+ import shlex
6
  from pathlib import Path
7
 
8
  from sim_rlbench.camera_spec import default_three_camera_spec
9
+ from sim_rlbench.peract2_runner import BenchmarkRunSpec, resolve_upstream_root
10
 
11
 
12
  def main() -> None:
 
21
  "camera_names": list(spec.cameras),
22
  "resolution": list(spec.resolution),
23
  "global_camera": spec.global_camera,
24
+ "resolved_upstream_root": str(resolve_upstream_root(Path(args.upstream_root))),
25
  }
26
 
27
  import_status = {}
 
41
  demo_path=Path(args.demo_path),
42
  camera_spec=spec,
43
  )
44
+ print(shlex.join(run_spec.train_command()))
45
 
46
 
47
  if __name__ == "__main__":