xinjie.wang committed on
Commit
7ecea11
·
1 Parent(s): 808531b
common.py CHANGED
@@ -189,7 +189,7 @@ os.makedirs(TMP_DIR, exist_ok=True)
 lighting_css = """
 <style>
 #lighter_mesh canvas {
-    filter: brightness(1.8) !important;
+    filter: brightness(1.9) !important;
 }
 </style>
 """
@@ -547,7 +547,9 @@ def extract_urdf(
 
     # Convert to URDF and recover attrs by GPT.
    filename = "sample"
-    urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
+    urdf_convertor = URDFGenerator(
+        GPT_CLIENT, render_view_num=4, decompose_convex=True
+    )
     asset_attrs = {
         "version": VERSION,
         "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
embodied_gen/data/convex_decomposer.py ADDED
@@ -0,0 +1,161 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import logging
18
+ import multiprocessing as mp
19
+ import os
20
+
21
+ import coacd
22
+ import trimesh
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ __all__ = [
27
+ "decompose_convex_coacd",
28
+ "decompose_convex_mesh",
29
+ "decompose_convex_mp",
30
+ ]
31
+
32
+
33
+ def decompose_convex_coacd(
34
+ filename: str, outfile: str, params: dict, verbose: bool = False
35
+ ) -> None:
36
+ coacd.set_log_level("info" if verbose else "warn")
37
+
38
+ mesh = trimesh.load(filename, force="mesh")
39
+ mesh = coacd.Mesh(mesh.vertices, mesh.faces)
40
+
41
+ result = coacd.run_coacd(mesh, **params)
42
+ combined = sum([trimesh.Trimesh(*m) for m in result])
43
+ combined.export(outfile)
44
+
45
+
46
+ def decompose_convex_mesh(
47
+ filename: str,
48
+ outfile: str,
49
+ threshold: float = 0.05,
50
+ max_convex_hull: int = -1,
51
+ preprocess_mode: str = "auto",
52
+ preprocess_resolution: int = 30,
53
+ resolution: int = 2000,
54
+ mcts_nodes: int = 20,
55
+ mcts_iterations: int = 150,
56
+ mcts_max_depth: int = 3,
57
+ pca: bool = False,
58
+ merge: bool = True,
59
+ seed: int = 0,
60
+ verbose: bool = False,
61
+ ) -> str:
62
+ """Decompose a mesh into convex parts using the CoACD algorithm."""
63
+ coacd.set_log_level("info" if verbose else "warn")
64
+
65
+ if os.path.exists(outfile):
66
+ logger.warning(f"Output file {outfile} already exists, removing it.")
67
+ os.remove(outfile)
68
+
69
+ params = dict(
70
+ threshold=threshold,
71
+ max_convex_hull=max_convex_hull,
72
+ preprocess_mode=preprocess_mode,
73
+ preprocess_resolution=preprocess_resolution,
74
+ resolution=resolution,
75
+ mcts_nodes=mcts_nodes,
76
+ mcts_iterations=mcts_iterations,
77
+ mcts_max_depth=mcts_max_depth,
78
+ pca=pca,
79
+ merge=merge,
80
+ seed=seed,
81
+ )
82
+
83
+ try:
84
+ decompose_convex_coacd(filename, outfile, params, verbose)
85
+ if os.path.exists(outfile):
86
+ return outfile
87
+ except Exception as e:
88
+ if verbose:
89
+ print(f"Decompose convex first attempt failed: {e}.")
90
+
91
+ if preprocess_mode != "on":
92
+ try:
93
+ params["preprocess_mode"] = "on"
94
+ decompose_convex_coacd(filename, outfile, params, verbose)
95
+ if os.path.exists(outfile):
96
+ return outfile
97
+ except Exception as e:
98
+ if verbose:
99
+ print(
100
+ f"Decompose convex second attempt with preprocess_mode='on' failed: {e}"
101
+ )
102
+
103
+ raise RuntimeError(f"Convex decomposition failed on {filename}")
104
+
105
+
106
+ def decompose_convex_mp(
107
+ filename: str,
108
+ outfile: str,
109
+ threshold: float = 0.05,
110
+ max_convex_hull: int = -1,
111
+ preprocess_mode: str = "auto",
112
+ preprocess_resolution: int = 30,
113
+ resolution: int = 2000,
114
+ mcts_nodes: int = 20,
115
+ mcts_iterations: int = 150,
116
+ mcts_max_depth: int = 3,
117
+ pca: bool = False,
118
+ merge: bool = True,
119
+ seed: int = 0,
120
+ verbose: bool = False,
121
+ ) -> str:
122
+ """Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
123
+
124
+ See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
125
+ """
126
+ params = dict(
127
+ threshold=threshold,
128
+ max_convex_hull=max_convex_hull,
129
+ preprocess_mode=preprocess_mode,
130
+ preprocess_resolution=preprocess_resolution,
131
+ resolution=resolution,
132
+ mcts_nodes=mcts_nodes,
133
+ mcts_iterations=mcts_iterations,
134
+ mcts_max_depth=mcts_max_depth,
135
+ pca=pca,
136
+ merge=merge,
137
+ seed=seed,
138
+ )
139
+
140
+ ctx = mp.get_context("spawn")
141
+ p = ctx.Process(
142
+ target=decompose_convex_coacd,
143
+ args=(filename, outfile, params, verbose),
144
+ )
145
+ p.start()
146
+ p.join()
147
+ if p.exitcode == 0 and os.path.exists(outfile):
148
+ return outfile
149
+
150
+ if preprocess_mode != "on":
151
+ params["preprocess_mode"] = "on"
152
+ p = ctx.Process(
153
+ target=decompose_convex_coacd,
154
+ args=(filename, outfile, params, verbose),
155
+ )
156
+ p.start()
157
+ p.join()
158
+ if p.exitcode == 0 and os.path.exists(outfile):
159
+ return outfile
160
+
161
+ raise RuntimeError(f"Convex decomposition failed on {filename}")
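Usage sketch (not part of the commit) for the new module; the input and output paths below are placeholders. `decompose_convex_mesh` runs CoACD in-process with a `preprocess_mode="on"` retry, while `decompose_convex_mp` runs the same parameters in a spawned process so a native CoACD crash cannot take down the caller.

```python
# Hedged usage sketch; file paths are hypothetical.
from embodied_gen.data.convex_decomposer import (
    decompose_convex_mesh,
    decompose_convex_mp,
)

# In-process decomposition with automatic fallback retry.
collision_obj = decompose_convex_mesh(
    filename="outputs/sample/mesh.obj",
    outfile="outputs/sample/mesh_collision.obj",
    threshold=0.05,
)

# Same parameters, but CoACD runs in a spawned subprocess for isolation.
collision_obj = decompose_convex_mp(
    filename="outputs/sample/mesh.obj",
    outfile="outputs/sample/mesh_collision.obj",
)
```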
embodied_gen/data/mesh_operator.py CHANGED
@@ -16,13 +16,17 @@
 
 
 import logging
+import multiprocessing as mp
+import os
 from typing import Tuple, Union
 
+import coacd
 import igraph
 import numpy as np
 import pyvista as pv
 import spaces
 import torch
+import trimesh
 import utils3d
 from pymeshfix import _meshfix
 from tqdm import tqdm
@@ -33,7 +37,9 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-__all__ = ["MeshFixer"]
+__all__ = [
+    "MeshFixer",
+]
 
 
 def _radical_inverse(base, n):
embodied_gen/envs/pick_embodiedgen.py ADDED
@@ -0,0 +1,389 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ from copy import deepcopy
20
+
21
+ import numpy as np
22
+ import sapien
23
+ import torch
24
+ import torchvision.transforms as transforms
25
+ from mani_skill.envs.sapien_env import BaseEnv
26
+ from mani_skill.sensors.camera import CameraConfig
27
+ from mani_skill.utils import sapien_utils
28
+ from mani_skill.utils.building import actors
29
+ from mani_skill.utils.registration import register_env
30
+ from mani_skill.utils.structs.actor import Actor
31
+ from mani_skill.utils.structs.pose import Pose
32
+ from mani_skill.utils.structs.types import (
33
+ GPUMemoryConfig,
34
+ SceneConfig,
35
+ SimConfig,
36
+ )
37
+ from mani_skill.utils.visualization.misc import tile_images
38
+ from tqdm import tqdm
39
+ from embodied_gen.models.gs_model import GaussianOperator
40
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
41
+ from embodied_gen.utils.geometry import bfs_placement, quaternion_multiply
42
+ from embodied_gen.utils.log import logger
43
+ from embodied_gen.utils.process_media import alpha_blend_rgba
44
+ from embodied_gen.utils.simulation import (
45
+ SIM_COORD_ALIGN,
46
+ load_assets_from_layout_file,
47
+ )
48
+
49
+ __all__ = ["PickEmbodiedGen"]
50
+
51
+
52
+ @register_env("PickEmbodiedGen-v1", max_episode_steps=100)
53
+ class PickEmbodiedGen(BaseEnv):
54
+ SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
55
+ goal_thresh = 0.0
56
+
57
+ def __init__(
58
+ self,
59
+ *args,
60
+ robot_uids: str | list[str] = "panda",
61
+ robot_init_qpos_noise: float = 0.02,
62
+ num_envs: int = 1,
63
+ reconfiguration_freq: int = None,
64
+ **kwargs,
65
+ ):
66
+ self.robot_init_qpos_noise = robot_init_qpos_noise
67
+ if reconfiguration_freq is None:
68
+ if num_envs == 1:
69
+ reconfiguration_freq = 1
70
+ else:
71
+ reconfiguration_freq = 0
72
+
73
+ # Init params from kwargs.
74
+ layout_file = kwargs.pop("layout_file", None)
75
+ replace_objs = kwargs.pop("replace_objs", True)
76
+ self.enable_grasp = kwargs.pop("enable_grasp", False)
77
+ self.init_quat = kwargs.pop("init_quat", [0.7071, 0, 0, 0.7071])
78
+ # Add small offset in z-axis to avoid collision.
79
+ self.objs_z_offset = kwargs.pop("objs_z_offset", 0.002)
80
+ self.robot_z_offset = kwargs.pop("robot_z_offset", 0.002)
81
+
82
+ self.layouts = self.init_env_layouts(
83
+ layout_file, num_envs, replace_objs
84
+ )
85
+ self.robot_pose = self.compute_robot_init_pose(
86
+ self.layouts, num_envs, self.robot_z_offset
87
+ )
88
+ self.env_actors = dict()
89
+ self.image_transform = transforms.PILToTensor()
90
+
91
+ super().__init__(
92
+ *args,
93
+ robot_uids=robot_uids,
94
+ reconfiguration_freq=reconfiguration_freq,
95
+ num_envs=num_envs,
96
+ **kwargs,
97
+ )
98
+
99
+ self.bg_images = dict()
100
+ if self.render_mode == "hybrid":
101
+ self.bg_images = self.render_gs3d_images(
102
+ self.layouts, num_envs, self.init_quat
103
+ )
104
+
105
+ @staticmethod
106
+ def init_env_layouts(
107
+ layout_file: str, num_envs: int, replace_objs: bool
108
+ ) -> list[LayoutInfo]:
109
+ layout = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
110
+ layouts = []
111
+ for env_idx in range(num_envs):
112
+ if replace_objs and env_idx > 0:
113
+ layout = bfs_placement(deepcopy(layout))
114
+ layouts.append(layout)
115
+
116
+ return layouts
117
+
118
+ @staticmethod
119
+ def compute_robot_init_pose(
120
+ layouts: list[LayoutInfo], num_envs: int, z_offset: float = 0.0
121
+ ) -> list[list[float]]:
122
+ robot_pose = []
123
+ for env_idx in range(num_envs):
124
+ layout = layouts[env_idx]
125
+ robot_node = layout.relation[Scene3DItemEnum.ROBOT.value]
126
+ x, y, z, qx, qy, qz, qw = layout.position[robot_node]
127
+ robot_pose.append([x, y, z + z_offset, qw, qx, qy, qz])
128
+
129
+ return robot_pose
130
+
131
+ @property
132
+ def _default_sim_config(self):
133
+ return SimConfig(
134
+ scene_config=SceneConfig(
135
+ solver_position_iterations=30,
136
+ # contact_offset=0.04,
137
+ # rest_offset=0.001,
138
+ ),
139
+ # sim_freq=200,
140
+ control_freq=50,
141
+ gpu_memory_config=GPUMemoryConfig(
142
+ max_rigid_contact_count=2**20, max_rigid_patch_count=2**19
143
+ ),
144
+ )
145
+
146
+ @property
147
+ def _default_sensor_configs(self):
148
+ pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
149
+
150
+ return [
151
+ CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)
152
+ ]
153
+
154
+ @property
155
+ def _default_human_render_camera_configs(self):
156
+ pose = sapien_utils.look_at(
157
+ eye=[0.9, 0.0, 1.1], target=[0.0, 0.0, 0.9]
158
+ )
159
+
160
+ return CameraConfig(
161
+ "render_camera", pose, 256, 256, np.deg2rad(75), 0.01, 100
162
+ )
163
+
164
+ def _load_agent(self, options: dict):
165
+ super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
166
+
167
+ def _load_scene(self, options: dict):
168
+ all_objects = []
169
+ logger.info("Loading assets and decomposing mesh collisions...")
170
+ for env_idx in range(self.num_envs):
171
+ env_actors = load_assets_from_layout_file(
172
+ self.scene,
173
+ self.layouts[env_idx],
174
+ z_offset=self.objs_z_offset,
175
+ init_quat=self.init_quat,
176
+ env_idx=env_idx,
177
+ )
178
+ self.env_actors[f"env{env_idx}"] = env_actors
179
+ all_objects.extend(env_actors.values())
180
+
181
+ self.obj = all_objects[-1]
182
+ for obj in all_objects:
183
+ self.remove_from_state_dict_registry(obj)
184
+
185
+ self.all_objects = Actor.merge(all_objects, name="all_objects")
186
+ self.add_to_state_dict_registry(self.all_objects)
187
+
188
+ self.goal_site = actors.build_sphere(
189
+ self.scene,
190
+ radius=self.goal_thresh,
191
+ color=[0, 1, 0, 0],
192
+ name="goal_site",
193
+ body_type="kinematic",
194
+ add_collision=False,
195
+ initial_pose=sapien.Pose(),
196
+ )
197
+ self._hidden_objects.append(self.goal_site)
198
+
199
+ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
200
+ with torch.device(self.device):
201
+ b = len(env_idx)
202
+ goal_xyz = torch.zeros((b, 3))
203
+ goal_xyz[:, :2] = torch.rand((b, 2)) * 0.2 - 0.1
204
+ self.goal_site.set_pose(Pose.create_from_pq(goal_xyz))
205
+
206
+ qpos = np.array(
207
+ [
208
+ 0.0,
209
+ np.pi / 8,
210
+ 0,
211
+ -np.pi * 3 / 8,
212
+ 0,
213
+ np.pi * 3 / 4,
214
+ np.pi / 4,
215
+ 0.04,
216
+ 0.04,
217
+ ]
218
+ )
219
+ qpos = (
220
+ np.random.normal(
221
+ 0, self.robot_init_qpos_noise, (self.num_envs, len(qpos))
222
+ )
223
+ + qpos
224
+ )
225
+ qpos[:, -2:] = 0.04
226
+ self.agent.robot.set_root_pose(np.array(self.robot_pose))
227
+ self.agent.reset(qpos)
228
+ self.agent.init_qpos = qpos
229
+ self.agent.controller.controllers["gripper"].reset()
230
+
231
+ def render_gs3d_images(
232
+ self, layouts: list[LayoutInfo], num_envs: int, init_quat: list[float]
233
+ ) -> dict[str, np.ndarray]:
234
+ sim_coord_align = (
235
+ torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
236
+ )
237
+ cameras = self.scene.sensors.copy()
238
+ cameras.update(self.scene.human_render_cameras)
239
+
240
+ bg_node = layouts[0].relation[Scene3DItemEnum.BACKGROUND.value]
241
+ gs_path = os.path.join(layouts[0].assets[bg_node], "gs_model.ply")
242
+ raw_gs: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
243
+ bg_images = dict()
244
+ for env_idx in tqdm(range(num_envs), desc="Pre-rendering Background"):
245
+ layout = layouts[env_idx]
246
+ x, y, z, qx, qy, qz, qw = layout.position[bg_node]
247
+ qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], init_quat)
248
+ init_pose = torch.tensor([x, y, z, qx, qy, qz, qw])
249
+ gs_model = raw_gs.get_gaussians(instance_pose=init_pose)
250
+ for key in cameras:
251
+ camera = cameras[key]
252
+ Ks = camera.camera.get_intrinsic_matrix() # (n_env, 3, 3)
253
+ c2w = camera.camera.get_model_matrix() # (n_env, 4, 4)
254
+ result = gs_model.render(
255
+ c2w[env_idx] @ sim_coord_align,
256
+ Ks[env_idx],
257
+ image_width=camera.config.width,
258
+ image_height=camera.config.height,
259
+ )
260
+ bg_images[f"{key}-env{env_idx}"] = result.rgb[..., ::-1]
261
+
262
+ return bg_images
263
+
264
+ def render(self):
265
+ if self.render_mode is None:
266
+ raise RuntimeError("render_mode is not set.")
267
+ if self.render_mode == "human":
268
+ return self.render_human()
269
+ elif self.render_mode == "rgb_array":
270
+ res = self.render_rgb_array()
271
+ return res
272
+ elif self.render_mode == "sensors":
273
+ res = self.render_sensors()
274
+ return res
275
+ elif self.render_mode == "all":
276
+ return self.render_all()
277
+ elif self.render_mode == "hybrid":
278
+ return self.hybrid_render()
279
+ else:
280
+ raise NotImplementedError(
281
+ f"Unsupported render mode {self.render_mode}."
282
+ )
283
+
284
+ def render_rgb_array(
285
+ self, camera_name: str = None, return_alpha: bool = False
286
+ ):
287
+ for obj in self._hidden_objects:
288
+ obj.show_visual()
289
+ self.scene.update_render(
290
+ update_sensors=False, update_human_render_cameras=True
291
+ )
292
+ images = []
293
+ render_images = self.scene.get_human_render_camera_images(
294
+ camera_name, return_alpha
295
+ )
296
+ for image in render_images.values():
297
+ images.append(image)
298
+ if len(images) == 0:
299
+ return None
300
+ if len(images) == 1:
301
+ return images[0]
302
+ for obj in self._hidden_objects:
303
+ obj.hide_visual()
304
+ return tile_images(images)
305
+
306
+ def render_sensors(self):
307
+ images = []
308
+ sensor_images = self.get_sensor_images()
309
+ for image in sensor_images.values():
310
+ for img in image.values():
311
+ images.append(img)
312
+ return tile_images(images)
313
+
314
+ def hybrid_render(self):
315
+ fg_images = self.render_rgb_array(
316
+ return_alpha=True
317
+ ) # (n_env, h, w, 3)
318
+ images = []
319
+ for key in self.bg_images:
320
+ if "render_camera" not in key:
321
+ continue
322
+ env_idx = int(key.split("-env")[-1])
323
+ rgba = alpha_blend_rgba(
324
+ fg_images[env_idx].cpu().numpy(), self.bg_images[key]
325
+ )
326
+ images.append(self.image_transform(rgba))
327
+
328
+ images = torch.stack(images, dim=0)
329
+ images = images.permute(0, 2, 3, 1)
330
+
331
+ return images[..., :3]
332
+
333
+ def evaluate(self):
334
+ obj_to_goal_pos = (
335
+ self.obj.pose.p
336
+ ) # self.goal_site.pose.p - self.obj.pose.p
337
+ is_obj_placed = (
338
+ torch.linalg.norm(obj_to_goal_pos, axis=1) <= self.goal_thresh
339
+ )
340
+ is_grasped = self.agent.is_grasping(self.obj)
341
+ is_robot_static = self.agent.is_static(0.2)
342
+
343
+ return dict(
344
+ is_grasped=is_grasped,
345
+ obj_to_goal_pos=obj_to_goal_pos,
346
+ is_obj_placed=is_obj_placed,
347
+ is_robot_static=is_robot_static,
348
+ is_grasping=self.agent.is_grasping(self.obj),
349
+ success=torch.logical_and(is_obj_placed, is_robot_static),
350
+ )
351
+
352
+ def _get_obs_extra(self, info: dict):
353
+
354
+ return dict()
355
+
356
+ def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
357
+ tcp_to_obj_dist = torch.linalg.norm(
358
+ self.obj.pose.p - self.agent.tcp.pose.p, axis=1
359
+ )
360
+ reaching_reward = 1 - torch.tanh(5 * tcp_to_obj_dist)
361
+ reward = reaching_reward
362
+
363
+ is_grasped = info["is_grasped"]
364
+ reward += is_grasped
365
+
366
+ # obj_to_goal_dist = torch.linalg.norm(
367
+ # self.goal_site.pose.p - self.obj.pose.p, axis=1
368
+ # )
369
+ obj_to_goal_dist = torch.linalg.norm(
370
+ self.obj.pose.p - self.obj.pose.p, axis=1
371
+ )
372
+ place_reward = 1 - torch.tanh(5 * obj_to_goal_dist)
373
+ reward += place_reward * is_grasped
374
+
375
+ reward += info["is_obj_placed"] * is_grasped
376
+
377
+ static_reward = 1 - torch.tanh(
378
+ 5
379
+ * torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], axis=1)
380
+ )
381
+ reward += static_reward * info["is_obj_placed"] * is_grasped
382
+
383
+ reward[info["success"]] = 6
384
+ return reward
385
+
386
+ def compute_normalized_dense_reward(
387
+ self, obs: any, action: torch.Tensor, info: dict
388
+ ):
389
+ return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
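A minimal construction sketch for the newly registered environment (the layout path is a placeholder produced by the layout pipeline); the keyword arguments mirror the ones popped in `__init__` and the id registered via `register_env`.

```python
# Hedged sketch: build the registered environment; layout_file is hypothetical.
import gymnasium as gym

import embodied_gen.envs.pick_embodiedgen  # noqa: F401, registers "PickEmbodiedGen-v1"

env = gym.make(
    "PickEmbodiedGen-v1",
    num_envs=2,
    render_mode="hybrid",  # rasterized foreground blended over pre-rendered GS background
    layout_file="outputs/task_0000/layout.json",
    control_mode="pd_joint_pos",
)
obs, info = env.reset(seed=0)
```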
embodied_gen/models/gs_model.py CHANGED
@@ -51,17 +51,15 @@ class RenderResult:
 
     def __post_init__(self):
         if isinstance(self.rgb, torch.Tensor):
-            rgb = self.rgb.detach().cpu().numpy()
-            rgb = (rgb * 255).astype(np.uint8)
-            self.rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
+            rgb = (self.rgb * 255).to(torch.uint8)
+            self.rgb = rgb.cpu().numpy()[..., ::-1]
         if isinstance(self.depth, torch.Tensor):
-            self.depth = self.depth.detach().cpu().numpy()
+            self.depth = self.depth.cpu().numpy()
         if isinstance(self.opacity, torch.Tensor):
-            opacity = self.opacity.detach().cpu().numpy()
-            opacity = (opacity * 255).astype(np.uint8)
-            self.opacity = cv2.cvtColor(opacity, cv2.COLOR_GRAY2RGB)
+            opacity = (self.opacity * 255).to(torch.uint8)
+            self.opacity = opacity.cpu().numpy()
         mask = np.where(self.opacity > self.mask_threshold, 255, 0)
-        self.mask = mask[..., 0:1].astype(np.uint8)
+        self.mask = mask.astype(np.uint8)
         self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
 
 
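The refactor replaces the OpenCV round trip with an on-tensor quantization plus a negative-stride channel flip. The two paths should agree, since `astype(np.uint8)` and `.to(torch.uint8)` both truncate; a small illustrative check (not part of the commit):

```python
import numpy as np
import torch

rgb = torch.rand(4, 4, 3)  # stand-in for a float render in [0, 1]

# Old path: to numpy, quantize, then swap channels (what cv2.COLOR_BGR2RGB does for 3 channels).
old = (rgb.detach().cpu().numpy() * 255).astype(np.uint8)[..., ::-1]
# New path: quantize on the tensor, then flip channels with a negative stride.
new = (rgb * 255).to(torch.uint8).cpu().numpy()[..., ::-1]

assert np.array_equal(old, new)
```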
embodied_gen/models/layout.py ADDED
@@ -0,0 +1,509 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+
24
+ import json_repair
25
+ from embodied_gen.utils.enum import (
26
+ LayoutInfo,
27
+ RobotItemEnum,
28
+ Scene3DItemEnum,
29
+ SpatialRelationEnum,
30
+ )
31
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
32
+ from embodied_gen.utils.process_media import SceneTreeVisualizer
33
+
34
+ logging.basicConfig(level=logging.INFO)
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ __all__ = [
39
+ "LayoutDesigner",
40
+ "LAYOUT_DISASSEMBLER",
41
+ "LAYOUT_GRAPHER",
42
+ "LAYOUT_DESCRIBER",
43
+ ]
44
+
45
+
46
+ DISTRACTOR_NUM = 3 # Maximum number of distractor objects allowed
47
+ LAYOUT_DISASSEMBLE_PROMPT = f"""
48
+ You are an intelligent 3D scene planner. Given a natural language
49
+ description of a robotic task, output a structured description of
50
+ an interactive 3D scene.
51
+
52
+ The output must include the following fields:
53
+ - task: A high-level task type (e.g., "single-arm pick",
54
+ "dual-arm grasping", "pick and place", "object sorting").
55
+ - {Scene3DItemEnum.ROBOT}: The name or type of robot involved. If not mentioned,
56
+ use {RobotItemEnum.FRANKA} as default.
57
+ - {Scene3DItemEnum.BACKGROUND}: The room or indoor environment where the task happens
58
+ (e.g., Kitchen, Bedroom, Living Room, Workshop, Office).
59
+ - {Scene3DItemEnum.CONTEXT}: An indoor object involved in the manipulation
60
+ (e.g., Table, Shelf, Desk, Bed, Cabinet).
61
+ - {Scene3DItemEnum.MANIPULATED_OBJS}: The main object(s) that the robot directly interacts with.
62
+ - {Scene3DItemEnum.DISTRACTOR_OBJS}: Other objects that naturally belong to the scene but are not part of the main task.
63
+
64
+ Constraints:
65
+ - The {Scene3DItemEnum.BACKGROUND} must logically match the described task.
66
+ - The {Scene3DItemEnum.CONTEXT} must fit within the {Scene3DItemEnum.BACKGROUND}. (e.g., a bedroom may include a table or bed, but not a workbench.)
67
+ - The {Scene3DItemEnum.CONTEXT} must be a concrete indoor object, such as a "table",
68
+ "shelf", "desk", or "bed". It must not be an abstract concept (e.g., "area", "space", "zone")
69
+ or structural surface (e.g., "floor", "ground"). If the input describes an interaction near
70
+ the floor or vague space, you must infer a plausible object like a "table", "cabinet", or "storage box" instead.
71
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} objects must be plausible,
72
+ and semantically compatible with the {Scene3DItemEnum.CONTEXT} and {Scene3DItemEnum.BACKGROUND}.
73
+ - {Scene3DItemEnum.DISTRACTOR_OBJS} must not confuse or overlap with the manipulated objects.
74
+ - {Scene3DItemEnum.DISTRACTOR_OBJS} number limit: {DISTRACTOR_NUM} distractors maximum.
75
+ - All {Scene3DItemEnum.BACKGROUND} are limited to indoor environments.
76
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} are rigid bodies and not include flexible objects.
77
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be common
78
+ household or office items or furniture, not abstract concepts, not too small like needle.
79
+ - If the input includes a plural or grouped object (e.g., "pens", "bottles", "plates", "fruit"),
80
+ you must decompose it into multiple individual instances (e.g., ["pen", "pen"], ["apple", "pear"]).
81
+ - Containers that hold objects (e.g., "bowl of apples", "box of tools") must
82
+ be separated into individual items (e.g., ["bowl", "apple", "apple"]).
83
+ - Do not include transparent objects such as "glass", "plastic", etc.
84
+ - The output must be in compact JSON format and use Markdown syntax, just like the output in the example below.
85
+
86
+ Examples:
87
+
88
+ Input:
89
+ "Pick up the marker from the table and put it in the bowl robot {RobotItemEnum.UR5}."
90
+ Output:
91
+ ```json
92
+ {{
93
+ "task_desc": "Pick up the marker from the table and put it in the bowl.",
94
+ "task": "pick and place",
95
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
96
+ "{Scene3DItemEnum.BACKGROUND}": "kitchen",
97
+ "{Scene3DItemEnum.CONTEXT}": "table",
98
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker"],
99
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["mug", "notebook", "bowl"]
100
+ }}
101
+ ```
102
+
103
+ Input:
104
+ "Put the rubik's cube on the top of the shelf."
105
+ Output:
106
+ ```json
107
+ {{
108
+ "task_desc": "Put the rubik's cube on the top of the shelf.",
109
+ "task": "pick and place",
110
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
111
+ "{Scene3DItemEnum.BACKGROUND}": "bedroom",
112
+ "{Scene3DItemEnum.CONTEXT}": "shelf",
113
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rubik's cube"],
114
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["pen", "cup", "toy car"]
115
+ }}
116
+ ```
117
+
118
+ Input:
119
+ "Remove all the objects from the white basket and put them on the table."
120
+ Output:
121
+ ```json
122
+ {{
123
+ "task_desc": "Remove all the objects from the white basket and put them on the table, robot {RobotItemEnum.PIPER}.",
124
+ "task": "pick and place",
125
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.PIPER}",
126
+ "{Scene3DItemEnum.BACKGROUND}": "office",
127
+ "{Scene3DItemEnum.CONTEXT}": "table",
128
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["banana", "mobile phone"],
129
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["plate", "white basket"]
130
+ }}
131
+ ```
132
+
133
+ Input:
134
+ "Pick up the rope on the chair and put it in the box."
135
+ Output:
136
+ ```json
137
+ {{
138
+ "task_desc": "Pick up the rope on the chair and put it in the box, robot {RobotItemEnum.FRANKA}.",
139
+ "task": "pick and place",
140
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
141
+ "{Scene3DItemEnum.BACKGROUND}": "living room",
142
+ "{Scene3DItemEnum.CONTEXT}": "chair",
143
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rope", "box"],
144
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["magazine"]
145
+ }}
146
+ ```
147
+
148
+ Input:
149
+ "Pick up the seal tape and plastic from the counter and put them in the open drawer and close it."
150
+ Output:
151
+ ```json
152
+ {{
153
+ "task_desc": "Pick up the seal tape and plastic from the counter and put them in the open drawer and close it.",
154
+ "task": "pick and place",
155
+ "robot": "franka",
156
+ "background": "kitchen",
157
+ "context": "counter",
158
+ "manipulated_objs": ["seal tape", "plastic", "opened drawer"],
159
+ "distractor_objs": ["scissors"]
160
+ }}
161
+ ```
162
+
163
+ Input:
164
+ "Put the pens in the grey bowl."
165
+ Output:
166
+ ```json
167
+ {{
168
+ "task_desc": "Put the pens in the grey bowl.",
169
+ "task": "pick and place",
170
+ "robot": "franka",
171
+ "background": "office",
172
+ "context": "table",
173
+ "manipulated_objs": ["pen", "pen", "grey bowl"],
174
+ "distractor_objs": ["notepad", "cup"]
175
+ }}
176
+ ```
177
+
178
+ """
179
+
180
+
181
+ LAYOUT_HIERARCHY_PROMPT = f"""
182
+ You are a 3D scene layout reasoning expert.
183
+ Your task is to generate a spatial relationship dictionary in multiway tree
184
+ that describes how objects are arranged in a 3D environment
185
+ based on a given task description and object list.
186
+
187
+ Input in JSON format containing the task description, task type,
188
+ {Scene3DItemEnum.ROBOT}, {Scene3DItemEnum.BACKGROUND}, {Scene3DItemEnum.CONTEXT},
189
+ and a list of objects, including {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS}.
190
+
191
+ ### Supported Spatial Relations:
192
+ - "{SpatialRelationEnum.ON}": The child object bottom is directly on top of the parent object top.
193
+ - "{SpatialRelationEnum.INSIDE}": The child object is inside the context object.
194
+ - "{SpatialRelationEnum.IN}": The {Scene3DItemEnum.ROBOT} in the {Scene3DItemEnum.BACKGROUND}.
195
+ - "{SpatialRelationEnum.FLOOR}": The child object bottom is on the floor of the {Scene3DItemEnum.BACKGROUND}.
196
+
197
+ ### Rules:
198
+ - The {Scene3DItemEnum.CONTEXT} object must be "{SpatialRelationEnum.FLOOR}" the {Scene3DItemEnum.BACKGROUND}.
199
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be either
200
+ "{SpatialRelationEnum.ON}" or "{SpatialRelationEnum.INSIDE}" the {Scene3DItemEnum.CONTEXT}
201
+ - Or "{SpatialRelationEnum.FLOOR}" {Scene3DItemEnum.BACKGROUND}.
202
+ - Use "{SpatialRelationEnum.INSIDE}" only if the parent is a container-like object (e.g., shelf, rack, cabinet).
203
+ - Do not define relationship edges between objects, only for the child and parent nodes.
204
+ - {Scene3DItemEnum.ROBOT} must "{SpatialRelationEnum.IN}" the {Scene3DItemEnum.BACKGROUND}.
205
+ - Ensure that each object appears only once in the layout tree, and its spatial relationship is defined with only one parent.
206
+ - Ensure a valid multiway tree structure with a maximum depth of 2 levels suitable for a 3D scene layout representation.
207
+ - Only output the final output in JSON format, using Markdown syntax as in examples.
208
+
209
+ ### Example
210
+ Input:
211
+ {{
212
+ "task_desc": "Pick up the marker from the table and put it in the bowl.",
213
+ "task": "pick and place",
214
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
215
+ "{Scene3DItemEnum.BACKGROUND}": "kitchen",
216
+ "{Scene3DItemEnum.CONTEXT}": "table",
217
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker", "bowl"],
218
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["mug", "chair"]
219
+ }}
220
+ Intermediate Think:
221
+ table {SpatialRelationEnum.FLOOR} kitchen
222
+ chair {SpatialRelationEnum.FLOOR} kitchen
223
+ {RobotItemEnum.FRANKA} {SpatialRelationEnum.IN} kitchen
224
+ marker {SpatialRelationEnum.ON} table
225
+ bowl {SpatialRelationEnum.ON} table
226
+ mug {SpatialRelationEnum.ON} table
227
+ Final Output:
228
+ ```json
229
+ {{
230
+ "kitchen": [
231
+ ["table", "{SpatialRelationEnum.FLOOR}"],
232
+ ["chair", "{SpatialRelationEnum.FLOOR}"],
233
+ ["{RobotItemEnum.FRANKA}", "{SpatialRelationEnum.IN}"]
234
+ ],
235
+ "table": [
236
+ ["marker", "{SpatialRelationEnum.ON}"],
237
+ ["bowl", "{SpatialRelationEnum.ON}"],
238
+ ["mug", "{SpatialRelationEnum.ON}"]
239
+ ]
240
+ }}
241
+ ```
242
+
243
+ Input:
244
+ {{
245
+ "task_desc": "Put the marker on top of the book.",
246
+ "task": "pick and place",
247
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
248
+ "{Scene3DItemEnum.BACKGROUND}": "office",
249
+ "{Scene3DItemEnum.CONTEXT}": "desk",
250
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker", "book"],
251
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["pen holder", "notepad"]
252
+ }}
253
+ Intermediate Think:
254
+ desk {SpatialRelationEnum.FLOOR} office
255
+ {RobotItemEnum.UR5} {SpatialRelationEnum.IN} office
256
+ marker {SpatialRelationEnum.ON} desk
257
+ book {SpatialRelationEnum.ON} desk
258
+ pen holder {SpatialRelationEnum.ON} desk
259
+ notepad {SpatialRelationEnum.ON} desk
260
+ Final Output:
261
+ ```json
262
+ {{
263
+ "office": [
264
+ ["desk", "{SpatialRelationEnum.FLOOR}"],
265
+ ["{RobotItemEnum.UR5}", "{SpatialRelationEnum.IN}"]
266
+ ],
267
+ "desk": [
268
+ ["marker", "{SpatialRelationEnum.ON}"],
269
+ ["book", "{SpatialRelationEnum.ON}"],
270
+ ["pen holder", "{SpatialRelationEnum.ON}"],
271
+ ["notepad", "{SpatialRelationEnum.ON}"]
272
+ ]
273
+ }}
274
+ ```
275
+
276
+ Input:
277
+ {{
278
+ "task_desc": "Put the rubik's cube on the top of the shelf.",
279
+ "task": "pick and place",
280
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
281
+ "{Scene3DItemEnum.BACKGROUND}": "bedroom",
282
+ "{Scene3DItemEnum.CONTEXT}": "shelf",
283
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rubik's cube"],
284
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["toy car", "pen"]
285
+ }}
286
+ Intermediate Think:
287
+ shelf {SpatialRelationEnum.FLOOR} bedroom
288
+ {RobotItemEnum.UR5} {SpatialRelationEnum.IN} bedroom
289
+ rubik's cube {SpatialRelationEnum.INSIDE} shelf
290
+ toy car {SpatialRelationEnum.INSIDE} shelf
291
+ pen {SpatialRelationEnum.INSIDE} shelf
292
+ Final Output:
293
+ ```json
294
+ {{
295
+ "bedroom": [
296
+ ["shelf", "{SpatialRelationEnum.FLOOR}"],
297
+ ["{RobotItemEnum.UR5}", "{SpatialRelationEnum.IN}"]
298
+ ],
299
+ "shelf": [
300
+ ["rubik's cube", "{SpatialRelationEnum.INSIDE}"],
301
+ ["toy car", "{SpatialRelationEnum.INSIDE}"],
302
+ ["pen", "{SpatialRelationEnum.INSIDE}"]
303
+ ]
304
+ }}
305
+ ```
306
+
307
+ Input:
308
+ {{
309
+ "task_desc": "Put the marker in the cup on the counter.",
310
+ "task": "pick and place",
311
+ "robot": "franka",
312
+ "background": "kitchen",
313
+ "context": "counter",
314
+ "manipulated_objs": ["marker", "cup"],
315
+ "distractor_objs": ["plate", "spoon"]
316
+ }}
317
+ Intermediate Think:
318
+ counter {SpatialRelationEnum.FLOOR} kitchen
319
+ {RobotItemEnum.FRANKA} {SpatialRelationEnum.IN} kitchen
320
+ marker {SpatialRelationEnum.ON} counter
321
+ cup {SpatialRelationEnum.ON} counter
322
+ plate {SpatialRelationEnum.ON} counter
323
+ spoon {SpatialRelationEnum.ON} counter
324
+ Final Output:
325
+ ```json
326
+ {{
327
+ "kitchen": [
328
+ ["counter", "{SpatialRelationEnum.FLOOR}"],
329
+ ["{RobotItemEnum.FRANKA}", "{SpatialRelationEnum.IN}"]
330
+ ],
331
+ "counter": [
332
+ ["marker", "{SpatialRelationEnum.ON}"],
333
+ ["cup", "{SpatialRelationEnum.ON}"],
334
+ ["plate", "{SpatialRelationEnum.ON}"],
335
+ ["spoon", "{SpatialRelationEnum.ON}"]
336
+ ]
337
+ }}
338
+ ```
339
+ """
340
+
341
+
342
+ LAYOUT_DESCRIBER_PROMPT = """
343
+ You are a 3D asset style descriptor.
344
+
345
+ Given a task description and a dictionary where the key is the object content and
346
+ the value is the object type, output a JSON dictionary with each object paired
347
+ with a concise, styled visual description suitable for 3D asset generation.
348
+
349
+ Generation Guidelines:
350
+ - For each object, brainstorm multiple style candidates before selecting the final
351
+ description. Vary phrasing, material, texture, color, and spatial details.
352
+ - Each description must be a maximum of 15 words, including color, style, materials.
353
+ - Descriptions should be visually grounded, specific, and reflect surface texture and structure.
354
+ - For objects marked as "context", explicitly mention the object is standalone, has an empty top.
355
+ - Use rich style descriptors: e.g., "scratched brown wooden desk" etc.
356
+ - Ensure all object styles align with the task's overall context and environment.
357
+
358
+ Format your output in JSON like the example below.
359
+
360
+ Example Input:
361
+ "Pick up the rope on the chair and put it in the box. {'living room': 'background', 'chair': 'context',
362
+ 'rope': 'manipulated_objs', 'box': 'manipulated_objs', 'magazine': 'distractor_objs'}"
363
+
364
+ Example Output:
365
+ ```json
366
+ {
367
+ "living room": "modern cozy living room with soft sunlight and light grey carpet",
368
+ "chair": "standalone dark oak chair with no surroundings and clean empty seat",
369
+ "rope": "twisted hemp rope with rough fibers and dusty beige texture",
370
+ "box": "slightly crumpled cardboard box with open flaps and brown textured surface",
371
+ "magazine": "celebrity magazine with glossy red cover and large bold title"
372
+ }
373
+ ```
374
+ """
375
+
376
+
377
+ class LayoutDesigner(object):
378
+ def __init__(
379
+ self,
380
+ gpt_client: GPTclient,
381
+ system_prompt: str,
382
+ verbose: bool = False,
383
+ ) -> None:
384
+ self.prompt = system_prompt.strip()
385
+ self.verbose = verbose
386
+ self.gpt_client = gpt_client
387
+
388
+ def query(self, prompt: str, params: dict = None) -> str:
389
+ full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
390
+
391
+ response = self.gpt_client.query(
392
+ text_prompt=full_prompt,
393
+ params=params,
394
+ )
395
+
396
+ if self.verbose:
397
+ logger.info(f"Response: {response}")
398
+
399
+ return response
400
+
401
+ def format_response(self, response: str) -> dict:
402
+ cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
403
+ try:
404
+ output = json.loads(cleaned)
405
+ except json.JSONDecodeError as e:
406
+ raise json.JSONDecodeError(
407
+ f"Failed to parse JSON response: {response}", e.doc, e.pos
408
+ ) from e
409
+
410
+ return output
411
+
412
+ def format_response_repair(self, response: str) -> dict:
413
+ return json_repair.loads(response)
414
+
415
+ def save_output(self, output: dict, save_path: str) -> None:
416
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
417
+ with open(save_path, 'w') as f:
418
+ json.dump(output, f, indent=4)
419
+
420
+ def __call__(
421
+ self, prompt: str, save_path: str = None, params: dict = None
422
+ ) -> dict | str:
423
+ response = self.query(prompt, params=params)
424
+ output = self.format_response_repair(response)
425
+ self.save_output(output, save_path) if save_path else None
426
+
427
+ return output
428
+
429
+
430
+ LAYOUT_DISASSEMBLER = LayoutDesigner(
431
+ gpt_client=GPT_CLIENT, system_prompt=LAYOUT_DISASSEMBLE_PROMPT
432
+ )
433
+ LAYOUT_GRAPHER = LayoutDesigner(
434
+ gpt_client=GPT_CLIENT, system_prompt=LAYOUT_HIERARCHY_PROMPT
435
+ )
436
+ LAYOUT_DESCRIBER = LayoutDesigner(
437
+ gpt_client=GPT_CLIENT, system_prompt=LAYOUT_DESCRIBER_PROMPT
438
+ )
439
+
440
+
441
+ def build_scene_layout(
442
+ task_desc: str, output_path: str = None, gpt_params: dict = None
443
+ ) -> LayoutInfo:
444
+ layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
445
+ layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
446
+ object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
447
+ obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
448
+ objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=gpt_params)
449
+ layout_info = LayoutInfo(
450
+ layout_tree, layout_relation, objs_desc, object_mapping
451
+ )
452
+
453
+ if output_path is not None:
454
+ visualizer = SceneTreeVisualizer(layout_info)
455
+ visualizer.render(save_path=output_path)
456
+ logger.info(f"Scene hierarchy tree saved to {output_path}")
457
+
458
+ return layout_info
459
+
460
+
461
+ def parse_args():
462
+ parser = argparse.ArgumentParser(description="3D Scene Layout Designer")
463
+ parser.add_argument(
464
+ "--task_desc",
465
+ type=str,
466
+ default="Put the apples on the table on the plate",
467
+ help="Natural language description of the robotic task",
468
+ )
469
+ parser.add_argument(
470
+ "--save_root",
471
+ type=str,
472
+ default="outputs/layout_tree",
473
+ help="Path to save the layout output",
474
+ )
475
+ return parser.parse_args()
476
+
477
+
478
+ if __name__ == "__main__":
479
+ from embodied_gen.utils.enum import LayoutInfo
480
+ from embodied_gen.utils.process_media import SceneTreeVisualizer
481
+
482
+ args = parse_args()
483
+ params = {
484
+ "temperature": 1.0,
485
+ "top_p": 0.95,
486
+ "frequency_penalty": 0.3,
487
+ "presence_penalty": 0.5,
488
+ }
489
+ layout_relation = LAYOUT_DISASSEMBLER(args.task_desc, params=params)
490
+ layout_tree = LAYOUT_GRAPHER(layout_relation, params=params)
491
+
492
+ object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
493
+ obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
494
+
495
+ objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=params)
496
+
497
+ layout_info = LayoutInfo(layout_tree, layout_relation, objs_desc)
498
+
499
+ visualizer = SceneTreeVisualizer(layout_info)
500
+ os.makedirs(args.save_root, exist_ok=True)
501
+ scene_graph_path = f"{args.save_root}/scene_tree.jpg"
502
+ visualizer.render(save_path=scene_graph_path)
503
+ with open(f"{args.save_root}/layout.json", "w") as f:
504
+ json.dump(layout_info.to_dict(), f, indent=4)
505
+
506
+ print(f"Scene hierarchy tree saved to {scene_graph_path}")
507
+ print(f"Disassembled Layout: {layout_relation}")
508
+ print(f"Layout Graph: {layout_tree}")
509
+ print(f"Layout Descriptions: {objs_desc}")
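A short sketch mirroring the `__main__` block (output paths are placeholders): build the layout for one task description and persist it.

```python
import json

from embodied_gen.models.layout import build_scene_layout

layout_info = build_scene_layout(
    "Put the apples on the table on the plate",
    output_path="outputs/layout_tree/scene_tree.jpg",  # also renders the scene tree
)
with open("outputs/layout_tree/layout.json", "w") as f:
    json.dump(layout_info.to_dict(), f, indent=4)
```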
embodied_gen/scripts/compose_layout.py ADDED
@@ -0,0 +1,73 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ from dataclasses import dataclass
20
+
21
+ import tyro
22
+ from embodied_gen.scripts.simulate_sapien import entrypoint as sim_cli
23
+ from embodied_gen.utils.enum import LayoutInfo
24
+ from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene
25
+ from embodied_gen.utils.log import logger
26
+
27
+
28
+ @dataclass
29
+ class LayoutPlacementConfig:
30
+ layout_path: str
31
+ output_dir: str | None = None
32
+ seed: int | None = None
33
+ max_attempts: int = 1000
34
+ output_iscene: bool = False
35
+ insert_robot: bool = False
36
+
37
+
38
+ def entrypoint(**kwargs):
39
+ if kwargs is None or len(kwargs) == 0:
40
+ args = tyro.cli(LayoutPlacementConfig)
41
+ else:
42
+ args = LayoutPlacementConfig(**kwargs)
43
+
44
+ output_dir = (
45
+ args.output_dir
46
+ if args.output_dir is not None
47
+ else os.path.dirname(args.layout_path)
48
+ )
49
+ os.makedirs(output_dir, exist_ok=True)
50
+ out_scene_path = f"{output_dir}/Iscene.glb"
51
+ out_layout_path = f"{output_dir}/layout.json"
52
+
53
+ with open(args.layout_path, "r") as f:
54
+ layout_info = LayoutInfo.from_dict(json.load(f))
55
+
56
+ layout_info = bfs_placement(layout_info, seed=args.seed)
57
+ with open(out_layout_path, "w") as f:
58
+ json.dump(layout_info.to_dict(), f, indent=4)
59
+
60
+ if args.output_iscene:
61
+ compose_mesh_scene(layout_info, out_scene_path)
62
+
63
+ sim_cli(
64
+ layout_path=out_layout_path,
65
+ output_dir=output_dir,
66
+ robot_name="franka" if args.insert_robot else None,
67
+ )
68
+
69
+ logger.info(f"Layout placement completed in {output_dir}")
70
+
71
+
72
+ if __name__ == "__main__":
73
+ entrypoint()
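Because `entrypoint` only falls back to `tyro.cli` when called without keyword arguments, the placement step can also be driven from Python; a sketch with placeholder paths:

```python
from embodied_gen.scripts.compose_layout import entrypoint

entrypoint(
    layout_path="outputs/task_0000/layout.json",  # placeholder layout file
    seed=0,
    output_iscene=True,   # additionally export the composed Iscene.glb
    insert_robot=True,    # re-simulate with a franka robot inserted
)
```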
embodied_gen/scripts/gen_layout.py ADDED
@@ -0,0 +1,156 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import gc
18
+ import json
19
+ import os
20
+ from dataclasses import dataclass, field
21
+ from shutil import copytree
22
+ from time import time
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import tyro
27
+ from embodied_gen.models.layout import build_scene_layout
28
+ from embodied_gen.scripts.simulate_sapien import entrypoint as sim_cli
29
+ from embodied_gen.scripts.textto3d import text_to_3d
30
+ from embodied_gen.utils.config import GptParamsConfig
31
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
32
+ from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene
33
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
34
+ from embodied_gen.utils.log import logger
35
+ from embodied_gen.utils.process_media import (
36
+ load_scene_dict,
37
+ parse_text_prompts,
38
+ )
39
+ from embodied_gen.validators.quality_checkers import SemanticMatcher
40
+
41
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
42
+
43
+
44
+ @dataclass
45
+ class LayoutGenConfig:
46
+ task_descs: list[str]
47
+ output_root: str
48
+ bg_list: str = "outputs/bg_scenes/scene_list.txt"
49
+ n_img_sample: int = 3
50
+ text_guidance_scale: float = 7.0
51
+ img_denoise_step: int = 25
52
+ n_image_retry: int = 4
53
+ n_asset_retry: int = 3
54
+ n_pipe_retry: int = 2
55
+ seed_img: Optional[int] = None
56
+ seed_3d: Optional[int] = None
57
+ seed_layout: Optional[int] = None
58
+ keep_intermediate: bool = False
59
+ output_iscene: bool = False
60
+ insert_robot: bool = False
61
+ gpt_params: GptParamsConfig = field(
62
+ default_factory=lambda: GptParamsConfig(
63
+ temperature=1.0,
64
+ top_p=0.95,
65
+ frequency_penalty=0.3,
66
+ presence_penalty=0.5,
67
+ )
68
+ )
69
+
70
+
71
+ def entrypoint() -> None:
72
+ args = tyro.cli(LayoutGenConfig)
73
+ SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
74
+ task_descs = parse_text_prompts(args.task_descs)
75
+ scene_dict = load_scene_dict(args.bg_list)
76
+ gpt_params = args.gpt_params.to_dict()
77
+ for idx, task_desc in enumerate(task_descs):
78
+ logger.info(f"Generate Layout and 3D scene for task: {task_desc}")
79
+ output_root = f"{args.output_root}/task_{idx:04d}"
80
+ scene_graph_path = f"{output_root}/scene_tree.jpg"
81
+ start_time = time()
82
+ layout_info: LayoutInfo = build_scene_layout(
83
+ task_desc, scene_graph_path, gpt_params
84
+ )
85
+ prompts_mapping = {v: k for k, v in layout_info.objs_desc.items()}
86
+ prompts = [
87
+ v
88
+ for k, v in layout_info.objs_desc.items()
89
+ if layout_info.objs_mapping[k] != Scene3DItemEnum.BACKGROUND.value
90
+ ]
91
+
92
+ for prompt in prompts:
93
+ node = prompts_mapping[prompt]
94
+ generation_log = text_to_3d(
95
+ prompts=[
96
+ prompt,
97
+ ],
98
+ output_root=output_root,
99
+ asset_names=[
100
+ node,
101
+ ],
102
+ n_img_sample=args.n_img_sample,
103
+ text_guidance_scale=args.text_guidance_scale,
104
+ img_denoise_step=args.img_denoise_step,
105
+ n_image_retry=args.n_image_retry,
106
+ n_asset_retry=args.n_asset_retry,
107
+ n_pipe_retry=args.n_pipe_retry,
108
+ seed_img=args.seed_img,
109
+ seed_3d=args.seed_3d,
110
+ keep_intermediate=args.keep_intermediate,
111
+ )
112
+ layout_info.assets.update(generation_log["assets"])
113
+ layout_info.quality.update(generation_log["quality"])
114
+
115
+ # Background GEN (for efficiency, temp use retrieval instead)
116
+ bg_node = layout_info.relation[Scene3DItemEnum.BACKGROUND.value]
117
+ text = layout_info.objs_desc[bg_node]
118
+ match_key = SCENE_MATCHER.query(text, str(scene_dict))
119
+ match_scene_path = f"{os.path.dirname(args.bg_list)}/{match_key}"
120
+ bg_save_dir = os.path.join(output_root, "background")
121
+ copytree(match_scene_path, bg_save_dir, dirs_exist_ok=True)
122
+ layout_info.assets[bg_node] = bg_save_dir
123
+
124
+ # BFS layout placement.
125
+ layout_info = bfs_placement(
126
+ layout_info,
127
+ limit_reach_range=True if args.insert_robot else False,
128
+ seed=args.seed_layout,
129
+ )
130
+ layout_path = f"{output_root}/layout.json"
131
+ with open(layout_path, "w") as f:
132
+ json.dump(layout_info.to_dict(), f, indent=4)
133
+
134
+ if args.output_iscene:
135
+ compose_mesh_scene(layout_info, f"{output_root}/Iscene.glb")
136
+
137
+ sim_cli(
138
+ layout_path=layout_path,
139
+ output_dir=output_root,
140
+ robot_name="franka" if args.insert_robot else None,
141
+ )
142
+
143
+ torch.cuda.empty_cache()
144
+ gc.collect()
145
+
146
+ elapsed_time = (time() - start_time) / 60
147
+ logger.info(
148
+ f"Layout generation done for {scene_graph_path}, layout result "
149
+ f"in {layout_path}, finished in {elapsed_time:.2f} mins."
150
+ )
151
+
152
+ logger.info(f"All tasks completed in {args.output_root}")
153
+
154
+
155
+ if __name__ == "__main__":
156
+ entrypoint()
embodied_gen/scripts/imageto3d.py CHANGED
@@ -58,7 +58,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
 os.environ["SPCONV_ALGO"] = "native"
 random.seed(0)
 
-logger.info("Loading Models...")
+logger.info("Loading Image3D Models...")
 DELIGHT = DelightingModel()
 IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 RBG_REMOVER = RembgRemover()
@@ -107,6 +107,7 @@ def parse_args():
         type=int,
         default=2,
     )
+    parser.add_argument("--disable_decompose_convex", action="store_true")
     args, unknown = parser.parse_known_args()
 
     return args
@@ -151,6 +152,9 @@ def entrypoint(**kwargs):
         seg_image.save(seg_path)
 
         seed = args.seed
+        asset_node = "unknown"
+        if isinstance(args.asset_type, list) and args.asset_type[idx]:
+            asset_node = args.asset_type[idx]
         for try_idx in range(args.n_retry):
             logger.info(
                 f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
@@ -207,7 +211,9 @@ def entrypoint(**kwargs):
             color_path = os.path.join(output_root, "color.png")
             render_gs_api(aligned_gs_path, color_path)
 
-            geo_flag, geo_result = GEO_CHECKER([color_path])
+            geo_flag, geo_result = GEO_CHECKER(
+                [color_path], text=asset_node
+            )
             logger.warning(
                 f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
             )
@@ -246,7 +252,11 @@ def entrypoint(**kwargs):
             mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
             mesh.export(mesh_glb_path)
 
-            urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
+            urdf_convertor = URDFGenerator(
+                GPT_CLIENT,
+                render_view_num=4,
+                decompose_convex=not args.disable_decompose_convex,
+            )
             asset_attrs = {
                 "version": VERSION,
                 "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
embodied_gen/scripts/parallel_sim.py ADDED
@@ -0,0 +1,148 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ from embodied_gen.utils.monkey_patches import monkey_patch_maniskill
19
+
20
+ monkey_patch_maniskill()
21
+ import json
22
+ from collections import defaultdict
23
+ from dataclasses import dataclass
24
+ from typing import Literal
25
+
26
+ import gymnasium as gym
27
+ import numpy as np
28
+ import torch
29
+ import tyro
30
+ from mani_skill.utils.wrappers import RecordEpisode
31
+ from tqdm import tqdm
32
+ import embodied_gen.envs.pick_embodiedgen
33
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
34
+ from embodied_gen.utils.log import logger
35
+ from embodied_gen.utils.simulation import FrankaPandaGrasper
36
+
37
+
38
+ @dataclass
39
+ class ParallelSimConfig:
40
+ """CLI parameters for Parallel Sapien simulation."""
41
+
42
+ # Environment configuration
43
+ layout_file: str
44
+ """Path to the layout JSON file"""
45
+ output_dir: str
46
+ """Directory to save recorded videos"""
47
+ gym_env_name: str = "PickEmbodiedGen-v1"
48
+ """Name of the Gym environment to use"""
49
+ num_envs: int = 4
50
+ """Number of parallel environments"""
51
+ render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
52
+ """Rendering mode: rgb_array or hybrid"""
53
+ enable_shadow: bool = True
54
+ """Whether to enable shadows in rendering"""
55
+ control_mode: str = "pd_joint_pos"
56
+ """Control mode for the agent"""
57
+
58
+ # Recording configuration
59
+ max_steps_per_video: int = 1000
60
+ """Maximum steps to record per video"""
61
+ save_trajectory: bool = False
62
+ """Whether to save trajectory data"""
63
+
64
+ # Simulation parameters
65
+ seed: int = 0
66
+ """Random seed for environment reset"""
67
+ warmup_steps: int = 50
68
+ """Number of warmup steps before action computation"""
69
+ reach_target_only: bool = True
70
+ """Whether to only reach target without full action"""
71
+
72
+
73
+ def entrypoint(**kwargs):
74
+ if kwargs is None or len(kwargs) == 0:
75
+ cfg = tyro.cli(ParallelSimConfig)
76
+ else:
77
+ cfg = ParallelSimConfig(**kwargs)
78
+
79
+ env = gym.make(
80
+ cfg.gym_env_name,
81
+ num_envs=cfg.num_envs,
82
+ render_mode=cfg.render_mode,
83
+ enable_shadow=cfg.enable_shadow,
84
+ layout_file=cfg.layout_file,
85
+ control_mode=cfg.control_mode,
86
+ )
87
+ env = RecordEpisode(
88
+ env,
89
+ cfg.output_dir,
90
+ max_steps_per_video=cfg.max_steps_per_video,
91
+ save_trajectory=cfg.save_trajectory,
92
+ )
93
+ env.reset(seed=cfg.seed)
94
+
95
+ default_action = env.unwrapped.agent.init_qpos[:, :8]
96
+ for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
97
+ # action = env.action_space.sample() # Random action
98
+ obs, reward, terminated, truncated, info = env.step(default_action)
99
+
100
+ grasper = FrankaPandaGrasper(
101
+ env.unwrapped.agent,
102
+ env.unwrapped.sim_config.control_freq,
103
+ )
104
+
105
+ layout_data = LayoutInfo.from_dict(json.load(open(cfg.layout_file, "r")))
106
+ actions = defaultdict(list)
107
+ # Plan a grasp reach pose for each manipulated object in each env.
108
+ for env_idx in range(env.num_envs):
109
+ actors = env.unwrapped.env_actors[f"env{env_idx}"]
110
+ for node in layout_data.relation[
111
+ Scene3DItemEnum.MANIPULATED_OBJS.value
112
+ ]:
113
+ action = grasper.compute_grasp_action(
114
+ actor=actors[node]._objs[0],
115
+ reach_target_only=True,
116
+ env_idx=env_idx,
117
+ )
118
+ actions[node].append(action)
119
+
120
+ # Execute the planned actions for each manipulated object in each env.
121
+ for node in actions:
122
+ max_env_steps = 0
123
+ for env_idx in range(env.num_envs):
124
+ if actions[node][env_idx] is None:
125
+ continue
126
+ max_env_steps = max(max_env_steps, len(actions[node][env_idx]))
127
+
128
+ action_tensor = np.ones(
129
+ (max_env_steps, env.num_envs, env.action_space.shape[-1])
130
+ )
131
+ action_tensor *= default_action[None, ...]
132
+ for env_idx in range(env.num_envs):
133
+ action = actions[node][env_idx]
134
+ if action is None:
135
+ continue
136
+ action_tensor[: len(action), env_idx, :] = action
137
+
138
+ for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
139
+ action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
140
+ env.unwrapped.agent.set_action(action)
141
+ obs, reward, terminated, truncated, info = env.step(action)
142
+
143
+ env.close()
144
+ logger.info(f"Results saved in {cfg.output_dir}")
145
+
146
+
147
+ if __name__ == "__main__":
148
+ entrypoint()
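entrypoint() above either parses ParallelSimConfig from the command line via tyro or builds it directly from keyword arguments. A minimal programmatic usage sketch, assuming the package is importable and using placeholder paths:

    # Usage sketch only; layout_file and output_dir are placeholder paths.
    from embodied_gen.scripts.parallel_sim import entrypoint

    entrypoint(
        layout_file="outputs/layouts/task_0000/layout.json",  # assumed layout JSON from layout generation
        output_dir="outputs/parallel_sim_videos",             # recorded videos are written here
        num_envs=4,
        render_mode="hybrid",
    )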
embodied_gen/scripts/simulate_sapien.py ADDED
@@ -0,0 +1,195 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import json
19
+ import os
20
+ from collections import defaultdict
21
+ from dataclasses import dataclass, field
22
+ from typing import Literal
23
+
24
+ import imageio
25
+ import numpy as np
26
+ import torch
27
+ import tyro
28
+ from tqdm import tqdm
29
+ from embodied_gen.models.gs_model import GaussianOperator
30
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
31
+ from embodied_gen.utils.geometry import quaternion_multiply
32
+ from embodied_gen.utils.log import logger
33
+ from embodied_gen.utils.process_media import alpha_blend_rgba
34
+ from embodied_gen.utils.simulation import (
35
+ SIM_COORD_ALIGN,
36
+ FrankaPandaGrasper,
37
+ SapienSceneManager,
38
+ load_assets_from_layout_file,
39
+ load_mani_skill_robot,
40
+ render_images,
41
+ )
42
+
43
+
44
+ @dataclass
45
+ class SapienSimConfig:
46
+ # Simulation settings.
47
+ layout_path: str
48
+ output_dir: str
49
+ sim_freq: int = 200
50
+ sim_step: int = 400
51
+ z_offset: float = 0.004
52
+ init_quat: list[float] = field(
53
+ default_factory=lambda: [0.7071, 0, 0, 0.7071]
54
+ ) # xyzw
55
+ device: str = "cuda"
56
+ control_freq: int = 50
57
+ insert_robot: bool = False
58
+ # Camera settings.
59
+ render_interval: int = 10
60
+ num_cameras: int = 3
61
+ camera_radius: float = 0.9
62
+ camera_height: float = 1.1
63
+ image_hw: tuple[int, int] = (512, 512)
64
+ ray_tracing: bool = True
65
+ fovy_deg: float = 75.0
66
+ camera_target_pt: list[float] = field(
67
+ default_factory=lambda: [0.0, 0.0, 0.9]
68
+ )
69
+ render_keys: list[
70
+ Literal[
71
+ "Color", "Foreground", "Segmentation", "Normal", "Mask", "Depth"
72
+ ]
73
+ ] = field(default_factory=lambda: ["Foreground"])
74
+
75
+
76
+ def entrypoint(**kwargs):
77
+ if kwargs is None or len(kwargs) == 0:
78
+ cfg = tyro.cli(SapienSimConfig)
79
+ else:
80
+ cfg = SapienSimConfig(**kwargs)
81
+
82
+ scene_manager = SapienSceneManager(
83
+ cfg.sim_freq, ray_tracing=cfg.ray_tracing
84
+ )
85
+ _ = scene_manager.initialize_circular_cameras(
86
+ num_cameras=cfg.num_cameras,
87
+ radius=cfg.camera_radius,
88
+ height=cfg.camera_height,
89
+ target_pt=cfg.camera_target_pt,
90
+ image_hw=cfg.image_hw,
91
+ fovy_deg=cfg.fovy_deg,
92
+ )
93
+ with open(cfg.layout_path, "r") as f:
94
+ layout_data = json.load(f)
95
+ layout_data: LayoutInfo = LayoutInfo.from_dict(layout_data)
96
+
97
+ actors = load_assets_from_layout_file(
98
+ scene_manager.scene,
99
+ layout_data,
100
+ cfg.z_offset,
101
+ cfg.init_quat,
102
+ )
103
+ agent = load_mani_skill_robot(
104
+ scene_manager.scene, layout_data, cfg.control_freq
105
+ )
106
+
107
+ frames = defaultdict(list)
108
+ image_cnt = 0
109
+ for step in tqdm(range(cfg.sim_step), desc="Simulation"):
110
+ scene_manager.scene.step()
111
+ agent.reset(agent.init_qpos)
112
+ if step % cfg.render_interval != 0:
113
+ continue
114
+ scene_manager.scene.update_render()
115
+ image_cnt += 1
116
+ for camera in scene_manager.cameras:
117
+ camera.take_picture()
118
+ images = render_images(camera, cfg.render_keys)
119
+ frames[camera.name].append(images)
120
+
121
+ actions = dict()
122
+ if cfg.insert_robot:
123
+ grasper = FrankaPandaGrasper(
124
+ agent,
125
+ cfg.control_freq,
126
+ )
127
+ for node in layout_data.relation[
128
+ Scene3DItemEnum.MANIPULATED_OBJS.value
129
+ ]:
130
+ actions[node] = grasper.compute_grasp_action(
131
+ actor=actors[node], reach_target_only=True
132
+ )
133
+
134
+ if "Foreground" not in cfg.render_keys:
135
+ return
136
+
137
+ bg_node = layout_data.relation[Scene3DItemEnum.BACKGROUND.value]
138
+ gs_path = f"{layout_data.assets[bg_node]}/gs_model.ply"
139
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
140
+ x, y, z, qx, qy, qz, qw = layout_data.position[bg_node]
141
+ qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], cfg.init_quat)
142
+ init_pose = torch.tensor([x, y, z, qx, qy, qz, qw])
143
+ gs_model = gs_model.get_gaussians(instance_pose=init_pose)
144
+
145
+ bg_images = dict()
146
+ for camera in scene_manager.cameras:
147
+ Ks = camera.get_intrinsic_matrix()
148
+ c2w = camera.get_model_matrix()
149
+ c2w = c2w @ SIM_COORD_ALIGN
150
+ result = gs_model.render(
151
+ torch.tensor(c2w, dtype=torch.float32).to(cfg.device),
152
+ torch.tensor(Ks, dtype=torch.float32).to(cfg.device),
153
+ image_width=cfg.image_hw[1],
154
+ image_height=cfg.image_hw[0],
155
+ )
156
+ bg_images[camera.name] = result.rgb[..., ::-1]
157
+
158
+ video_frames = []
159
+ for camera in scene_manager.cameras:
160
+ # Scene rendering
161
+ for step in range(image_cnt):
162
+ rgba = alpha_blend_rgba(
163
+ frames[camera.name][step]["Foreground"],
164
+ bg_images[camera.name],
165
+ )
166
+ video_frames.append(np.array(rgba))
167
+
168
+ # Grasp rendering
169
+ for node in actions:
170
+ if actions[node] is None:
171
+ continue
172
+ for action in tqdm(actions[node]):
173
+ grasp_frames = scene_manager.step_action(
174
+ agent,
175
+ torch.Tensor(action[None, ...]),
176
+ scene_manager.cameras,
177
+ cfg.render_keys,
178
+ sim_steps_per_control=cfg.sim_freq // cfg.control_freq,
179
+ )
180
+ rgba = alpha_blend_rgba(
181
+ grasp_frames[camera.name][0]["Foreground"],
182
+ bg_images[camera.name],
183
+ )
184
+ video_frames.append(np.array(rgba))
185
+
186
+ agent.reset(agent.init_qpos)
187
+
188
+ os.makedirs(cfg.output_dir, exist_ok=True)
189
+ video_path = f"{cfg.output_dir}/Iscene.mp4"
190
+ imageio.mimsave(video_path, video_frames, fps=30)
191
+ logger.info(f"Interative 3D Scene Visualization saved in {video_path}")
192
+
193
+
194
+ if __name__ == "__main__":
195
+ entrypoint()
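As with parallel_sim.py, entrypoint() accepts either CLI arguments (via tyro) or keyword arguments mapped onto SapienSimConfig. A minimal usage sketch with placeholder paths:

    # Usage sketch only; layout_path and output_dir are placeholder paths.
    from embodied_gen.scripts.simulate_sapien import entrypoint

    entrypoint(
        layout_path="outputs/layouts/task_0000/layout.json",  # assumed layout JSON
        output_dir="outputs/sapien_sim",                      # Iscene.mp4 is written here
        num_cameras=3,
        insert_robot=False,
    )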
embodied_gen/scripts/textto3d.py CHANGED
@@ -42,7 +42,7 @@ from embodied_gen.validators.quality_checkers import (
42
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
43
  random.seed(0)
44
 
45
- logger.info("Loading Models...")
46
  SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
47
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
48
  TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
@@ -170,6 +170,7 @@ def text_to_3d(**kwargs) -> dict:
170
  seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
171
  n_retry=args.n_asset_retry,
172
  keep_intermediate=args.keep_intermediate,
 
173
  )
174
  mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
175
  image_path = render_asset3d(
@@ -270,6 +271,7 @@ def parse_args():
270
  help="Random seed for 3D generation",
271
  )
272
  parser.add_argument("--keep_intermediate", action="store_true")
 
273
 
274
  args, unknown = parser.parse_known_args()
275
 
 
42
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
43
  random.seed(0)
44
 
45
+ logger.info("Loading TEXT2IMG_MODEL...")
46
  SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
47
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
48
  TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
 
170
  seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
171
  n_retry=args.n_asset_retry,
172
  keep_intermediate=args.keep_intermediate,
173
+ disable_decompose_convex=args.disable_decompose_convex,
174
  )
175
  mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
176
  image_path = render_asset3d(
 
271
  help="Random seed for 3D generation",
272
  )
273
  parser.add_argument("--keep_intermediate", action="store_true")
274
+ parser.add_argument("--disable_decompose_convex", action="store_true")
275
 
276
  args, unknown = parser.parse_known_args()
277
 
embodied_gen/scripts/textto3d.sh CHANGED
@@ -81,6 +81,7 @@ done
81
 
82
 
83
  # Step 1: Text-to-Image
 
84
  eval python3 embodied_gen/scripts/text2image.py \
85
  --prompts ${prompt_args} \
86
  --output_root "${output_root}/images" \
 
81
 
82
 
83
  # Step 1: Text-to-Image
84
+ echo ${prompt_args}
85
  eval python3 embodied_gen/scripts/text2image.py \
86
  --prompts ${prompt_args} \
87
  --output_root "${output_root}/images" \
embodied_gen/trainer/gsplat_trainer.py CHANGED
@@ -617,7 +617,7 @@ class Runner:
617
  for rgb, depth in images_cache:
618
  depth_normalized = torch.clip(
619
  (depth - depth_global_min)
620
- / (depth_global_max - depth_global_min),
621
  0,
622
  1,
623
  )
 
617
  for rgb, depth in images_cache:
618
  depth_normalized = torch.clip(
619
  (depth - depth_global_min)
620
+ / (depth_global_max - depth_global_min + 1e-8),
621
  0,
622
  1,
623
  )
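The added 1e-8 term guards against a zero denominator when a cached depth map is constant (depth_global_max == depth_global_min). A small self-contained illustration of the guarded normalization:

    # Illustration only: with a constant depth map the unguarded division would produce NaNs.
    import torch

    depth = torch.full((4, 4), 2.0)
    d_min, d_max = depth.min(), depth.max()          # identical values here
    normalized = torch.clip((depth - d_min) / (d_max - d_min + 1e-8), 0, 1)
    assert torch.isfinite(normalized).all()          # all zeros, no NaN/inf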
embodied_gen/trainer/pono2mesh_trainer.py CHANGED
@@ -30,7 +30,7 @@ from kornia.morphology import dilation
30
  from PIL import Image
31
  from embodied_gen.models.sr_model import ImageRealESRGAN
32
  from embodied_gen.utils.config import Pano2MeshSRConfig
33
- from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
34
  from embodied_gen.utils.log import logger
35
  from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
36
  from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
 
30
  from PIL import Image
31
  from embodied_gen.models.sr_model import ImageRealESRGAN
32
  from embodied_gen.utils.config import Pano2MeshSRConfig
33
+ from embodied_gen.utils.geometry import compute_pinhole_intrinsics
34
  from embodied_gen.utils.log import logger
35
  from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
36
  from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
embodied_gen/utils/config.py CHANGED
@@ -17,15 +17,27 @@
17
  from dataclasses import dataclass, field
18
  from typing import List, Optional, Union
19
 
 
20
  from gsplat.strategy import DefaultStrategy, MCMCStrategy
21
  from typing_extensions import Literal, assert_never
22
 
23
  __all__ = [
 
24
  "Pano2MeshSRConfig",
25
  "GsplatTrainConfig",
26
  ]
27
 
28

29
  @dataclass
30
  class Pano2MeshSRConfig:
31
  mesh_file: str = "mesh_model.ply"
 
17
  from dataclasses import dataclass, field
18
  from typing import List, Optional, Union
19
 
20
+ from dataclasses_json import DataClassJsonMixin
21
  from gsplat.strategy import DefaultStrategy, MCMCStrategy
22
  from typing_extensions import Literal, assert_never
23
 
24
  __all__ = [
25
+ "GptParamsConfig",
26
  "Pano2MeshSRConfig",
27
  "GsplatTrainConfig",
28
  ]
29
 
30
 
31
+ @dataclass
32
+ class GptParamsConfig(DataClassJsonMixin):
33
+ temperature: float = 0.1
34
+ top_p: float = 0.1
35
+ frequency_penalty: float = 0.0
36
+ presence_penalty: float = 0.0
37
+ stop: int | None = None
38
+ max_tokens: int = 500
39
+
40
+
41
  @dataclass
42
  class Pano2MeshSRConfig:
43
  mesh_file: str = "mesh_model.ply"
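GptParamsConfig mixes in DataClassJsonMixin, so GPT sampling parameters can be round-tripped through plain dicts/JSON when they are logged or passed to the client. A minimal sketch of that round trip:

    # Sketch of the JSON round-trip provided by dataclasses_json.
    from embodied_gen.utils.config import GptParamsConfig

    params = GptParamsConfig(temperature=0.2, max_tokens=256)
    payload = params.to_dict()                     # {"temperature": 0.2, "top_p": 0.1, ...}
    restored = GptParamsConfig.from_dict(payload)  # back to a dataclass instance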
embodied_gen/utils/enum.py CHANGED
@@ -102,6 +102,7 @@ class LayoutInfo(DataClassJsonMixin):
102
  tree: dict[str, list]
103
  relation: dict[str, str | list[str]]
104
  objs_desc: dict[str, str] = field(default_factory=dict)
 
105
  assets: dict[str, str] = field(default_factory=dict)
106
  quality: dict[str, str] = field(default_factory=dict)
107
  position: dict[str, list[float]] = field(default_factory=dict)
 
102
  tree: dict[str, list]
103
  relation: dict[str, str | list[str]]
104
  objs_desc: dict[str, str] = field(default_factory=dict)
105
+ objs_mapping: dict[str, str] = field(default_factory=dict)
106
  assets: dict[str, str] = field(default_factory=dict)
107
  quality: dict[str, str] = field(default_factory=dict)
108
  position: dict[str, list[float]] = field(default_factory=dict)
embodied_gen/utils/gaussian.py CHANGED
@@ -35,7 +35,6 @@ __all__ = [
35
  "set_random_seed",
36
  "export_splats",
37
  "create_splats_with_optimizers",
38
- "compute_pinhole_intrinsics",
39
  "resize_pinhole_intrinsics",
40
  "restore_scene_scale_and_position",
41
  ]
@@ -265,12 +264,12 @@ def create_splats_with_optimizers(
265
  return splats, optimizers
266
 
267
 
268
- def compute_pinhole_intrinsics(
269
- image_w: int, image_h: int, fov_deg: float
270
  ) -> np.ndarray:
271
- fov_rad = np.deg2rad(fov_deg)
272
- fx = image_w / (2 * np.tan(fov_rad / 2))
273
- fy = fx # assuming square pixels
274
  cx = image_w / 2
275
  cy = image_h / 2
276
  K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
 
35
  "set_random_seed",
36
  "export_splats",
37
  "create_splats_with_optimizers",
 
38
  "resize_pinhole_intrinsics",
39
  "restore_scene_scale_and_position",
40
  ]
 
264
  return splats, optimizers
265
 
266
 
267
+ def compute_intrinsics_from_fovy(
268
+ image_w: int, image_h: int, fovy_deg: float
269
  ) -> np.ndarray:
270
+ fovy_rad = np.deg2rad(fovy_deg)
271
+ fy = image_h / (2 * np.tan(fovy_rad / 2))
272
+ fx = fy * (image_w / image_h)
273
  cx = image_w / 2
274
  cy = image_h / 2
275
  K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
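The renamed helper now derives the focal length from the vertical field of view and the image height (fy), then scales fx by the aspect ratio, matching cameras that are specified by fovy such as the ones created in simulate_sapien.py. A worked example with assumed values:

    # Worked example (assumed values): 512 x 512 image, fovy = 75 degrees.
    from embodied_gen.utils.gaussian import compute_intrinsics_from_fovy

    K = compute_intrinsics_from_fovy(image_w=512, image_h=512, fovy_deg=75.0)
    # fy = 512 / (2 * tan(37.5 deg)) ~= 333.7; fx = fy for a square image; cx = cy = 256.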
embodied_gen/utils/geometry.py ADDED
@@ -0,0 +1,458 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import os
18
+ import random
19
+ from collections import defaultdict, deque
20
+ from functools import wraps
21
+ from typing import Literal
22
+
23
+ import numpy as np
24
+ import torch
25
+ import trimesh
26
+ from matplotlib.path import Path
27
+ from pyquaternion import Quaternion
28
+ from scipy.spatial import ConvexHull
29
+ from scipy.spatial.transform import Rotation as R
30
+ from shapely.geometry import Polygon
31
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
32
+ from embodied_gen.utils.log import logger
33
+
34
+ __all__ = [
35
+ "bfs_placement",
36
+ "with_seed",
37
+ "matrix_to_pose",
38
+ "pose_to_matrix",
39
+ "quaternion_multiply",
40
+ "check_reachable",
42
+ "compose_mesh_scene",
43
+ "compute_pinhole_intrinsics",
44
+ ]
45
+
46
+
47
+ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
48
+ """Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
49
+
50
+ Args:
51
+ matrix (np.ndarray): 4x4 transformation matrix.
52
+
53
+ Returns:
54
+ List[float]: Pose as [x, y, z, qx, qy, qz, qw].
55
+ """
56
+ x, y, z = matrix[:3, 3]
57
+ rot_mat = matrix[:3, :3]
58
+ quat = R.from_matrix(rot_mat).as_quat()
59
+ qx, qy, qz, qw = quat
60
+
61
+ return [x, y, z, qx, qy, qz, qw]
62
+
63
+
64
+ def pose_to_matrix(pose: list[float]) -> np.ndarray:
65
+ """Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
66
+
67
+ Args:
68
+ pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw].
69
+
70
+ Returns:
71
+ matrix (np.ndarray): 4x4 transformation matrix.
72
+ """
73
+ x, y, z, qx, qy, qz, qw = pose
74
+ r = R.from_quat([qx, qy, qz, qw])
75
+ matrix = np.eye(4)
76
+ matrix[:3, :3] = r.as_matrix()
77
+ matrix[:3, 3] = [x, y, z]
78
+
79
+ return matrix
80
+
81
+
82
+ def compute_xy_bbox(
83
+ vertices: np.ndarray, col_x: int = 0, col_y: int = 2
84
+ ) -> list[float]:
85
+ x_vals = vertices[:, col_x]
86
+ y_vals = vertices[:, col_y]
87
+ return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
88
+
89
+
90
+ def has_iou_conflict(
91
+ new_box: list[float],
92
+ placed_boxes: list[list[float]],
93
+ iou_threshold: float = 0.0,
94
+ ) -> bool:
95
+ new_min_x, new_max_x, new_min_y, new_max_y = new_box
96
+ for min_x, max_x, min_y, max_y in placed_boxes:
97
+ ix1 = max(new_min_x, min_x)
98
+ iy1 = max(new_min_y, min_y)
99
+ ix2 = min(new_max_x, max_x)
100
+ iy2 = min(new_max_y, max_y)
101
+ inter_area = max(0, ix2 - ix1) * max(0, iy2 - iy1)
102
+ if inter_area > iou_threshold:
103
+ return True
104
+ return False
105
+
106
+
107
+ def with_seed(seed_attr_name: str = "seed"):
108
+ """A parameterized decorator that temporarily sets the random seed."""
109
+
110
+ def decorator(func):
111
+ @wraps(func)
112
+ def wrapper(*args, **kwargs):
113
+ seed = kwargs.get(seed_attr_name, None)
114
+ if seed is not None:
115
+ py_state = random.getstate()
116
+ np_state = np.random.get_state()
117
+ torch_state = torch.get_rng_state()
118
+
119
+ random.seed(seed)
120
+ np.random.seed(seed)
121
+ torch.manual_seed(seed)
122
+ try:
123
+ result = func(*args, **kwargs)
124
+ finally:
125
+ random.setstate(py_state)
126
+ np.random.set_state(np_state)
127
+ torch.set_rng_state(torch_state)
128
+ return result
129
+ else:
130
+ return func(*args, **kwargs)
131
+
132
+ return wrapper
133
+
134
+ return decorator
135
+
136
+
137
+ def compute_convex_hull_path(
138
+ vertices: np.ndarray,
139
+ z_threshold: float = 0.05,
140
+ interp_per_edge: int = 3,
141
+ margin: float = -0.02,
142
+ ) -> Path:
143
+ top_vertices = vertices[
144
+ vertices[:, 1] > vertices[:, 1].max() - z_threshold
145
+ ]
146
+ top_xy = top_vertices[:, [0, 2]]
147
+
148
+ if len(top_xy) < 3:
149
+ raise ValueError("Not enough points to form a convex hull")
150
+
151
+ hull = ConvexHull(top_xy)
152
+ hull_points = top_xy[hull.vertices]
153
+
154
+ polygon = Polygon(hull_points)
155
+ polygon = polygon.buffer(margin)
156
+ hull_points = np.array(polygon.exterior.coords)
157
+
158
+ dense_points = []
159
+ for i in range(len(hull_points)):
160
+ p1 = hull_points[i]
161
+ p2 = hull_points[(i + 1) % len(hull_points)]
162
+ for t in np.linspace(0, 1, interp_per_edge, endpoint=False):
163
+ pt = (1 - t) * p1 + t * p2
164
+ dense_points.append(pt)
165
+
166
+ return Path(np.array(dense_points), closed=True)
167
+
168
+
169
+ def find_parent_node(node: str, tree: dict) -> str | None:
170
+ for parent, children in tree.items():
171
+ if any(child[0] == node for child in children):
172
+ return parent
173
+ return None
174
+
175
+
176
+ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
177
+ x1, x2, y1, y2 = box
178
+ corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
179
+
180
+ num_inside = sum(hull.contains_point(c) for c in corners)
181
+ return num_inside >= threshold
182
+
183
+
184
+ def compute_axis_rotation_quat(
185
+ axis: Literal["x", "y", "z"], angle_rad: float
186
+ ) -> list[float]:
187
+ if axis.lower() == 'x':
188
+ q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
189
+ elif axis.lower() == 'y':
190
+ q = Quaternion(axis=[0, 1, 0], angle=angle_rad)
191
+ elif axis.lower() == 'z':
192
+ q = Quaternion(axis=[0, 0, 1], angle=angle_rad)
193
+ else:
194
+ raise ValueError(f"Unsupported axis '{axis}', must be one of x, y, z")
195
+
196
+ return [q.x, q.y, q.z, q.w]
197
+
198
+
199
+ def quaternion_multiply(
200
+ init_quat: list[float], rotate_quat: list[float]
201
+ ) -> list[float]:
202
+ qx, qy, qz, qw = init_quat
203
+ q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
204
+ qx, qy, qz, qw = rotate_quat
205
+ q2 = Quaternion(w=qw, x=qx, y=qy, z=qz)
206
+ quat = q2 * q1
207
+
208
+ return [quat.x, quat.y, quat.z, quat.w]
209
+
210
+
211
+ def check_reachable(
212
+ base_xyz: np.ndarray,
213
+ reach_xyz: np.ndarray,
214
+ min_reach: float = 0.25,
215
+ max_reach: float = 0.85,
216
+ ) -> bool:
217
+ """Check if the target point is within the reachable range."""
218
+ distance = np.linalg.norm(reach_xyz - base_xyz)
219
+
220
+ return min_reach < distance < max_reach
221
+
222
+
223
+ @with_seed("seed")
224
+ def bfs_placement(
225
+ layout_info: LayoutInfo,
226
+ floor_margin: float = 0,
227
+ beside_margin: float = 0.1,
228
+ max_attempts: int = 3000,
229
+ rotate_objs: bool = True,
230
+ rotate_bg: bool = True,
231
+ limit_reach_range: bool = True,
232
+ robot_dim: float = 0.12,
233
+ seed: int = None,
234
+ ) -> LayoutInfo:
235
+ object_mapping = layout_info.objs_mapping
236
+ position = {} # node: [x, y, z, qx, qy, qz, qw]
237
+ parent_bbox_xy = {}
238
+ placed_boxes_map = defaultdict(list)
239
+ mesh_info = defaultdict(dict)
240
+ robot_node = layout_info.relation[Scene3DItemEnum.ROBOT.value]
241
+ for node in object_mapping:
242
+ if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
243
+ bg_quat = (
244
+ compute_axis_rotation_quat(
245
+ axis="y",
246
+ angle_rad=np.random.uniform(0, 2 * np.pi),
247
+ )
248
+ if rotate_bg
249
+ else [0, 0, 0, 1]
250
+ )
251
+ bg_quat = [round(q, 4) for q in bg_quat]
252
+ continue
253
+
254
+ mesh_path = (
255
+ f"{layout_info.assets[node]}/mesh/{node.replace(' ', '_')}.obj"
256
+ )
257
+ mesh_info[node]["path"] = mesh_path
258
+ mesh = trimesh.load(mesh_path)
259
+ vertices = mesh.vertices
260
+ z1 = np.percentile(vertices[:, 1], 1)
261
+ z2 = np.percentile(vertices[:, 1], 99)
262
+
263
+ if object_mapping[node] == Scene3DItemEnum.CONTEXT.value:
264
+ object_quat = [0, 0, 0, 1]
265
+ mesh_info[node]["surface"] = compute_convex_hull_path(vertices)
266
+ # Place the robot at the edge of the CONTEXT surface.
267
+ x, y = random.choice(mesh_info[node]["surface"].vertices)
268
+ theta = np.arctan2(y, x)
269
+ quat_initial = Quaternion(axis=[0, 0, 1], angle=theta)
270
+ quat_extra = Quaternion(axis=[0, 0, 1], angle=np.pi)
271
+ quat = quat_extra * quat_initial
272
+ _pose = [x, y, z2 - z1, quat.x, quat.y, quat.z, quat.w]
273
+ position[robot_node] = [round(v, 4) for v in _pose]
274
+ node_box = [
275
+ x - robot_dim / 2,
276
+ x + robot_dim / 2,
277
+ y - robot_dim / 2,
278
+ y + robot_dim / 2,
279
+ ]
280
+ placed_boxes_map[node].append(node_box)
281
+ elif rotate_objs:
282
+ # For manipulated and distractor objects, apply random rotation
283
+ angle_rad = np.random.uniform(0, 2 * np.pi)
284
+ object_quat = compute_axis_rotation_quat(
285
+ axis="y", angle_rad=angle_rad
286
+ )
287
+ object_quat_scipy = np.roll(object_quat, 1) # [w, x, y, z]
288
+ rotation = R.from_quat(object_quat_scipy).as_matrix()
289
+ vertices = np.dot(mesh.vertices, rotation.T)
290
+ z1 = np.percentile(vertices[:, 1], 1)
291
+ z2 = np.percentile(vertices[:, 1], 99)
292
+
293
+ x1, x2, y1, y2 = compute_xy_bbox(vertices)
294
+ mesh_info[node]["pose"] = [x1, x2, y1, y2, z1, z2, *object_quat]
295
+ mesh_info[node]["area"] = max(1e-5, (x2 - x1) * (y2 - y1))
296
+
297
+ root = list(layout_info.tree.keys())[0]
298
+ queue = deque([((root, None), layout_info.tree.get(root, []))])
299
+ while queue:
300
+ (node, relation), children = queue.popleft()
301
+ if node not in object_mapping:
302
+ continue
303
+
304
+ if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
305
+ position[node] = [0, 0, floor_margin, *bg_quat]
306
+ else:
307
+ x1, x2, y1, y2, z1, z2, qx, qy, qz, qw = mesh_info[node]["pose"]
308
+ if object_mapping[node] == Scene3DItemEnum.CONTEXT.value:
309
+ position[node] = [0, 0, -round(z1, 4), qx, qy, qz, qw]
310
+ parent_bbox_xy[node] = [x1, x2, y1, y2, z1, z2]
311
+ elif object_mapping[node] in [
312
+ Scene3DItemEnum.MANIPULATED_OBJS.value,
313
+ Scene3DItemEnum.DISTRACTOR_OBJS.value,
314
+ ]:
315
+ parent_node = find_parent_node(node, layout_info.tree)
316
+ parent_pos = position[parent_node]
317
+ (
318
+ p_x1,
319
+ p_x2,
320
+ p_y1,
321
+ p_y2,
322
+ p_z1,
323
+ p_z2,
324
+ ) = parent_bbox_xy[parent_node]
325
+
326
+ obj_dx = x2 - x1
327
+ obj_dy = y2 - y1
328
+ hull_path = mesh_info[parent_node].get("surface")
329
+ for _ in range(max_attempts):
330
+ node_x1 = random.uniform(p_x1, p_x2 - obj_dx)
331
+ node_y1 = random.uniform(p_y1, p_y2 - obj_dy)
332
+ node_box = [
333
+ node_x1,
334
+ node_x1 + obj_dx,
335
+ node_y1,
336
+ node_y1 + obj_dy,
337
+ ]
338
+ if hull_path and not all_corners_inside(
339
+ hull_path, node_box
340
+ ):
341
+ continue
342
+ # Make sure the manipulated object is reachable by robot.
343
+ if (
344
+ limit_reach_range
345
+ and object_mapping[node]
346
+ == Scene3DItemEnum.MANIPULATED_OBJS.value
347
+ ):
348
+ cx = parent_pos[0] + node_box[0] + obj_dx / 2
349
+ cy = parent_pos[1] + node_box[2] + obj_dy / 2
350
+ cz = parent_pos[2] + p_z2 - z1
351
+ robot_pose = position[robot_node][:3]
352
+ if not check_reachable(
353
+ base_xyz=np.array(robot_pose),
354
+ reach_xyz=np.array([cx, cy, cz]),
355
+ ):
356
+ continue
357
+
358
+ if not has_iou_conflict(
359
+ node_box, placed_boxes_map[parent_node]
360
+ ):
361
+ z_offset = 0
362
+ break
363
+ else:
364
+ logger.warning(
365
+ f"Cannot place {node} on {parent_node} without overlap"
366
+ f" after {max_attempts} attempts, place beside {parent_node}."
367
+ )
368
+ for _ in range(max_attempts):
369
+ node_x1 = random.choice(
370
+ [
371
+ random.uniform(
372
+ p_x1 - obj_dx - beside_margin,
373
+ p_x1 - obj_dx,
374
+ ),
375
+ random.uniform(p_x2, p_x2 + beside_margin),
376
+ ]
377
+ )
378
+ node_y1 = random.choice(
379
+ [
380
+ random.uniform(
381
+ p_y1 - obj_dy - beside_margin,
382
+ p_y1 - obj_dy,
383
+ ),
384
+ random.uniform(p_y2, p_y2 + beside_margin),
385
+ ]
386
+ )
387
+ node_box = [
388
+ node_x1,
389
+ node_x1 + obj_dx,
390
+ node_y1,
391
+ node_y1 + obj_dy,
392
+ ]
393
+ z_offset = -(parent_pos[2] + p_z2)
394
+ if not has_iou_conflict(
395
+ node_box, placed_boxes_map[parent_node]
396
+ ):
397
+ break
398
+
399
+ placed_boxes_map[parent_node].append(node_box)
400
+
401
+ abs_cx = parent_pos[0] + node_box[0] + obj_dx / 2
402
+ abs_cy = parent_pos[1] + node_box[2] + obj_dy / 2
403
+ abs_cz = parent_pos[2] + p_z2 - z1 + z_offset
404
+ position[node] = [
405
+ round(v, 4)
406
+ for v in [abs_cx, abs_cy, abs_cz, qx, qy, qz, qw]
407
+ ]
408
+ parent_bbox_xy[node] = [x1, x2, y1, y2, z1, z2]
409
+
410
+ sorted_children = sorted(
411
+ children, key=lambda x: -mesh_info[x[0]].get("area", 0)
412
+ )
413
+ for child, rel in sorted_children:
414
+ queue.append(((child, rel), layout_info.tree.get(child, [])))
415
+
416
+ layout_info.position = position
417
+
418
+ return layout_info
419
+
420
+
421
+ def compose_mesh_scene(
422
+ layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
423
+ ) -> None:
424
+ object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
425
+ scene = trimesh.Scene()
426
+ for node in layout_info.assets:
427
+ if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
428
+ mesh_path = f"{layout_info.assets[node]}/mesh_model.ply"
429
+ if not with_bg:
430
+ continue
431
+ else:
432
+ mesh_path = (
433
+ f"{layout_info.assets[node]}/mesh/{node.replace(' ', '_')}.obj"
434
+ )
435
+
436
+ mesh = trimesh.load(mesh_path)
437
+ offset = np.array(layout_info.position[node])[[0, 2, 1]]
438
+ mesh.vertices += offset
439
+ scene.add_geometry(mesh, node_name=node)
440
+
441
+ os.makedirs(os.path.dirname(out_scene_path), exist_ok=True)
442
+ scene.export(out_scene_path)
443
+ logger.info(f"Composed interactive 3D layout saved in {out_scene_path}")
444
+
445
+ return
446
+
447
+
448
+ def compute_pinhole_intrinsics(
449
+ image_w: int, image_h: int, fov_deg: float
450
+ ) -> np.ndarray:
451
+ fov_rad = np.deg2rad(fov_deg)
452
+ fx = image_w / (2 * np.tan(fov_rad / 2))
453
+ fy = fx # assuming square pixels
454
+ cx = image_w / 2
455
+ cy = image_h / 2
456
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
457
+
458
+ return K
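A quick round-trip sketch for the pose helpers defined above (values are arbitrary; quaternions are in x, y, z, w order throughout this module):

    # Round-trip sketch for the new pose utilities.
    from embodied_gen.utils.geometry import matrix_to_pose, pose_to_matrix, quaternion_multiply

    pose = [0.1, 0.2, 0.3, 0.0, 0.0, 0.7071, 0.7071]   # x, y, z, qx, qy, qz, qw
    T = pose_to_matrix(pose)                            # 4x4 homogeneous transform
    recovered = matrix_to_pose(T)                       # ~= pose, up to quaternion sign and rounding
    same = quaternion_multiply(pose[3:], [0, 0, 0, 1])  # multiplying by identity leaves the rotation unchanged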
embodied_gen/utils/monkey_patches.py CHANGED
@@ -18,6 +18,7 @@ import os
18
  import sys
19
  import zipfile
20
 
 
21
  import torch
22
  from huggingface_hub import hf_hub_download
23
  from omegaconf import OmegaConf
@@ -150,3 +151,68 @@ def monkey_patch_pano2room():
150
  self.inpaint_pipe = pipe
151
 
152
  SDFTInpainter.__init__ = patched_sd_inpaint_init

18
  import sys
19
  import zipfile
20
 
21
+ import numpy as np
22
  import torch
23
  from huggingface_hub import hf_hub_download
24
  from omegaconf import OmegaConf
 
151
  self.inpaint_pipe = pipe
152
 
153
  SDFTInpainter.__init__ = patched_sd_inpaint_init
154
+
155
+
156
+ def monkey_patch_maniskill():
157
+ from mani_skill.envs.scene import ManiSkillScene
158
+
159
+ def get_sensor_images(
160
+ self, obs: dict[str, any]
161
+ ) -> dict[str, dict[str, torch.Tensor]]:
162
+ sensor_data = dict()
163
+ for name, sensor in self.sensors.items():
164
+ sensor_data[name] = sensor.get_images(obs[name])
165
+ return sensor_data
166
+
167
+ def get_human_render_camera_images(
168
+ self, camera_name: str = None, return_alpha: bool = False
169
+ ) -> dict[str, torch.Tensor]:
170
+ def get_rgba_tensor(camera, return_alpha):
171
+ color = camera.get_obs(
172
+ rgb=True, depth=False, segmentation=False, position=False
173
+ )["rgb"]
174
+ if return_alpha:
175
+ seg_labels = camera.get_obs(
176
+ rgb=False, depth=False, segmentation=True, position=False
177
+ )["segmentation"]
178
+ masks = np.where((seg_labels.cpu() > 0), 255, 0).astype(
179
+ np.uint8
180
+ )
181
+ masks = torch.tensor(masks).to(color.device)
182
+ color = torch.concat([color, masks], dim=-1)
183
+
184
+ return color
185
+
186
+ image_data = dict()
187
+ if self.gpu_sim_enabled:
188
+ if self.parallel_in_single_scene:
189
+ for name, camera in self.human_render_cameras.items():
190
+ camera.camera._render_cameras[0].take_picture()
191
+ rgba = get_rgba_tensor(camera, return_alpha)
192
+ image_data[name] = rgba
193
+ else:
194
+ for name, camera in self.human_render_cameras.items():
195
+ if camera_name is not None and name != camera_name:
196
+ continue
197
+ assert camera.config.shader_config.shader_pack not in [
198
+ "rt",
199
+ "rt-fast",
200
+ "rt-med",
201
+ ], "ray tracing shaders do not work with parallel rendering"
202
+ camera.capture()
203
+ rgba = get_rgba_tensor(camera, return_alpha)
204
+ image_data[name] = rgba
205
+ else:
206
+ for name, camera in self.human_render_cameras.items():
207
+ if camera_name is not None and name != camera_name:
208
+ continue
209
+ camera.capture()
210
+ rgba = get_rgba_tensor(camera, return_alpha)
211
+ image_data[name] = rgba
212
+
213
+ return image_data
214
+
215
+ ManiSkillScene.get_sensor_images = get_sensor_images
216
+ ManiSkillScene.get_human_render_camera_images = (
217
+ get_human_render_camera_images
218
+ )
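monkey_patch_maniskill() swaps the ManiSkillScene image helpers above in place, so it has to run before any ManiSkill scene or environment is constructed, which is the import order used in parallel_sim.py. A hedged ordering sketch:

    # Ordering sketch: patch first, create environments afterwards.
    from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

    monkey_patch_maniskill()                     # replaces ManiSkillScene.get_sensor_images, etc.
    import gymnasium as gym
    import embodied_gen.envs.pick_embodiedgen    # registers PickEmbodiedGen-v1 (as in parallel_sim.py)
    # env = gym.make("PickEmbodiedGen-v1", num_envs=1, layout_file="layout.json")  # placeholder args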
embodied_gen/utils/process_media.py CHANGED
@@ -166,7 +166,7 @@ def combine_images_to_grid(
166
  images: list[str | Image.Image],
167
  cat_row_col: tuple[int, int] = None,
168
  target_wh: tuple[int, int] = (512, 512),
169
- ) -> list[str | Image.Image]:
170
  n_images = len(images)
171
  if n_images == 1:
172
  return images
@@ -377,6 +377,42 @@ def parse_text_prompts(prompts: list[str]) -> list[str]:
377
  return prompts
378
 
379

380
  def check_object_edge_truncated(
381
  mask: np.ndarray, edge_threshold: int = 5
382
  ) -> bool:
@@ -400,8 +436,15 @@ def check_object_edge_truncated(
400
 
401
 
402
  if __name__ == "__main__":
403
- merge_video_video(
404
- "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
405
- "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
406
- "merge.mp4",
407
- )
 
 
 
 
 
 
 
 
166
  images: list[str | Image.Image],
167
  cat_row_col: tuple[int, int] = None,
168
  target_wh: tuple[int, int] = (512, 512),
169
+ ) -> list[Image.Image]:
170
  n_images = len(images)
171
  if n_images == 1:
172
  return images
 
377
  return prompts
378
 
379
 
380
+ def alpha_blend_rgba(
381
+ fg_image: Union[str, Image.Image, np.ndarray],
382
+ bg_image: Union[str, Image.Image, np.ndarray],
383
+ ) -> Image.Image:
384
+ """Alpha blends a foreground RGBA image over a background RGBA image.
385
+
386
+ Args:
387
+ fg_image: Foreground image. Can be a file path (str), a PIL Image,
388
+ or a NumPy ndarray.
389
+ bg_image: Background image. Can be a file path (str), a PIL Image,
390
+ or a NumPy ndarray.
391
+
392
+ Returns:
393
+ A PIL Image representing the alpha-blended result in RGBA mode.
394
+ """
395
+ if isinstance(fg_image, str):
396
+ fg_image = Image.open(fg_image)
397
+ elif isinstance(fg_image, np.ndarray):
398
+ fg_image = Image.fromarray(fg_image)
399
+
400
+ if isinstance(bg_image, str):
401
+ bg_image = Image.open(bg_image)
402
+ elif isinstance(bg_image, np.ndarray):
403
+ bg_image = Image.fromarray(bg_image)
404
+
405
+ if fg_image.size != bg_image.size:
406
+ raise ValueError(
407
+ f"Image sizes not match {fg_image.size} v.s. {bg_image.size}."
408
+ )
409
+
410
+ fg = fg_image.convert("RGBA")
411
+ bg = bg_image.convert("RGBA")
412
+
413
+ return Image.alpha_composite(bg, fg)
414
+
415
+
416
  def check_object_edge_truncated(
417
  mask: np.ndarray, edge_threshold: int = 5
418
  ) -> bool:
 
436
 
437
 
438
  if __name__ == "__main__":
439
+ image_paths = [
440
+ "outputs/layouts_sim/task_0000/images/pen.png",
441
+ "outputs/layouts_sim/task_0000/images/notebook.png",
442
+ "outputs/layouts_sim/task_0000/images/mug.png",
443
+ "outputs/layouts_sim/task_0000/images/lamp.png",
444
+ "outputs/layouts_sim2/task_0014/images/cloth.png", # TODO
445
+ ]
446
+ for image_path in image_paths:
447
+ image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
448
+ mask = image[..., -1]
449
+ flag = check_object_edge_truncated(mask)
450
+ print(flag, image_path)
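alpha_blend_rgba() accepts file paths, PIL images, or arrays, as long as both inputs have the same size. A minimal self-contained sketch compositing a transparent foreground over a solid background:

    # Minimal sketch: a fully transparent foreground leaves the background visible.
    import numpy as np
    from embodied_gen.utils.process_media import alpha_blend_rgba

    fg = np.zeros((512, 512, 4), dtype=np.uint8)       # RGBA foreground, alpha = 0 everywhere
    bg = np.full((512, 512, 3), 255, dtype=np.uint8)   # white RGB background
    blended = alpha_blend_rgba(fg, bg)                 # PIL.Image in RGBA mode, still white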
embodied_gen/utils/simulation.py ADDED
@@ -0,0 +1,633 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import logging
19
+ import os
20
+ import xml.etree.ElementTree as ET
21
+ from collections import defaultdict
22
+ from typing import Literal
23
+
24
+ import mplib
25
+ import numpy as np
26
+ import sapien.core as sapien
27
+ import sapien.physx as physx
28
+ import torch
29
+ from mani_skill.agents.base_agent import BaseAgent
30
+ from mani_skill.envs.scene import ManiSkillScene
31
+ from mani_skill.examples.motionplanning.panda.utils import (
32
+ compute_grasp_info_by_obb,
33
+ )
34
+ from mani_skill.utils.geometry.trimesh_utils import get_component_mesh
35
+ from PIL import Image, ImageColor
36
+ from scipy.spatial.transform import Rotation as R
37
+ from embodied_gen.data.utils import DiffrastRender
38
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
39
+ from embodied_gen.utils.geometry import quaternion_multiply
40
+ from embodied_gen.utils.log import logger
41
+
42
+ COLORMAP = list(set(ImageColor.colormap.values()))
43
+ COLOR_PALETTE = np.array(
44
+ [ImageColor.getrgb(c) for c in COLORMAP], dtype=np.uint8
45
+ )
46
+ SIM_COORD_ALIGN = np.array(
47
+ [
48
+ [1.0, 0.0, 0.0, 0.0],
49
+ [0.0, -1.0, 0.0, 0.0],
50
+ [0.0, 0.0, -1.0, 0.0],
51
+ [0.0, 0.0, 0.0, 1.0],
52
+ ]
53
+ ) # Used to align SAPIEN, MuJoCo coordinate system with the world coordinate system
54
+
55
+ __all__ = [
56
+ "SIM_COORD_ALIGN",
57
+ "FrankaPandaGrasper",
58
+ "load_assets_from_layout_file",
59
+ "load_mani_skill_robot",
60
+ "render_images",
61
+ ]
62
+
63
+
64
+ def load_actor_from_urdf(
65
+ scene: ManiSkillScene | sapien.Scene,
66
+ file_path: str,
67
+ pose: sapien.Pose,
68
+ env_idx: int = None,
69
+ use_static: bool = False,
70
+ update_mass: bool = False,
71
+ ) -> sapien.pysapien.Entity:
72
+ tree = ET.parse(file_path)
73
+ root = tree.getroot()
74
+ node_name = root.get("name")
75
+ file_dir = os.path.dirname(file_path)
76
+ visual_file = root.find('.//visual/geometry/mesh').get("filename")
77
+ collision_file = root.find('.//collision/geometry/mesh').get("filename")
78
+ visual_file = os.path.join(file_dir, visual_file)
79
+ collision_file = os.path.join(file_dir, collision_file)
80
+ static_fric = root.find('.//collision/gazebo/mu1').text
81
+ dynamic_fric = root.find('.//collision/gazebo/mu2').text
82
+
83
+ material = physx.PhysxMaterial(
84
+ static_friction=np.clip(float(static_fric), 0.1, 0.7),
85
+ dynamic_friction=np.clip(float(dynamic_fric), 0.1, 0.6),
86
+ restitution=0.05,
87
+ )
88
+ builder = scene.create_actor_builder()
89
+
90
+ body_type = "static" if use_static else "dynamic"
91
+ builder.set_physx_body_type(body_type)
92
+ builder.add_multiple_convex_collisions_from_file(
93
+ collision_file if body_type == "dynamic" else visual_file,
94
+ material=material,
95
+ # decomposition="coacd",
96
+ # decomposition_params=dict(
97
+ # threshold=0.05, max_convex_hull=64, verbose=False
98
+ # ),
99
+ )
100
+
101
+ builder.add_visual_from_file(visual_file)
102
+ builder.set_initial_pose(pose)
103
+ if isinstance(scene, ManiSkillScene) and env_idx is not None:
104
+ builder.set_scene_idxs([env_idx])
105
+
106
+ actor = builder.build(name=f"{node_name}-{env_idx}")
107
+
108
+ if update_mass and hasattr(actor.components[1], "mass"):
109
+ node_mass = float(root.find('.//inertial/mass').get("value"))
110
+ actor.components[1].set_mass(node_mass)
111
+
112
+ return actor
113
+
114
+
115
+ def load_assets_from_layout_file(
116
+ scene: ManiSkillScene | sapien.Scene,
117
+ layout: LayoutInfo | str,
118
+ z_offset: float = 0.0,
119
+ init_quat: list[float] = [0, 0, 0, 1],
120
+ env_idx: int = None,
121
+ ) -> dict[str, sapien.pysapien.Entity]:
122
+ """Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
123
+
124
+ Args:
125
+ scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
126
+ layout (LayoutInfo): The layout information data.
127
+ z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
128
+ init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
129
+ env_idx (int): Environment index for multi-environment setup.
130
+ """
131
+ if isinstance(layout, str) and layout.endswith(".json"):
132
+ layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
133
+
134
+ actors = dict()
135
+ for node in layout.assets:
136
+ file_dir = layout.assets[node]
137
+ file_name = f"{node.replace(' ', '_')}.urdf"
138
+ urdf_file = os.path.join(file_dir, file_name)
139
+
140
+ if layout.objs_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
141
+ continue
142
+
143
+ position = layout.position[node].copy()
144
+ if layout.objs_mapping[node] != Scene3DItemEnum.CONTEXT.value:
145
+ position[2] += z_offset
146
+
147
+ use_static = (
148
+ layout.relation.get(Scene3DItemEnum.CONTEXT.value, None) == node
149
+ )
150
+
151
+ # Combine initial quaternion with object quaternion
152
+ x, y, z, qx, qy, qz, qw = position
153
+ qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], init_quat)
154
+ actor = load_actor_from_urdf(
155
+ scene,
156
+ urdf_file,
157
+ sapien.Pose(p=[x, y, z], q=[qw, qx, qy, qz]),
158
+ env_idx,
159
+ use_static=use_static,
160
+ update_mass=False,
161
+ )
162
+ actors[node] = actor
163
+
164
+ return actors
165
+
166
+
167
+ def load_mani_skill_robot(
168
+ scene: sapien.Scene | ManiSkillScene,
169
+ layout: LayoutInfo | str,
170
+ control_freq: int = 20,
171
+ robot_init_qpos_noise: float = 0.0,
172
+ control_mode: str = "pd_joint_pos",
173
+ backend_str: tuple[str, str] = ("cpu", "gpu"),
174
+ ) -> BaseAgent:
175
+ from mani_skill.agents import REGISTERED_AGENTS
176
+ from mani_skill.envs.scene import ManiSkillScene
177
+ from mani_skill.envs.utils.system.backend import (
178
+ parse_sim_and_render_backend,
179
+ )
180
+
181
+ if isinstance(layout, str) and layout.endswith(".json"):
182
+ layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
183
+
184
+ robot_name = layout.relation[Scene3DItemEnum.ROBOT.value]
185
+ x, y, z, qx, qy, qz, qw = layout.position[robot_name]
186
+ delta_z = 0.002 # Add small offset to avoid collision.
187
+ pose = sapien.Pose([x, y, z + delta_z], [qw, qx, qy, qz])
188
+
189
+ if robot_name not in REGISTERED_AGENTS:
190
+ logger.warning(
191
+ f"Robot `{robot_name}` not registered, chosen from {REGISTERED_AGENTS.keys()}, use `panda` instead."
192
+ )
193
+ robot_name = "panda"
194
+
195
+ ROBOT_CLS = REGISTERED_AGENTS[robot_name].agent_cls
196
+ backend = parse_sim_and_render_backend(*backend_str)
197
+ if isinstance(scene, sapien.Scene):
198
+ scene = ManiSkillScene([scene], device=backend_str[0], backend=backend)
199
+ robot = ROBOT_CLS(
200
+ scene=scene,
201
+ control_freq=control_freq,
202
+ control_mode=control_mode,
203
+ initial_pose=pose,
204
+ )
205
+
206
+ # Set the robot's initial joint angles in radians (joint0 to joint6, plus 2 gripper fingers).
207
+ qpos = np.array(
208
+ [
209
+ 0.0,
210
+ np.pi / 8,
211
+ 0,
212
+ -np.pi * 3 / 8,
213
+ 0,
214
+ np.pi * 3 / 4,
215
+ np.pi / 4,
216
+ 0.04,
217
+ 0.04,
218
+ ]
219
+ )
220
+ qpos = (
221
+ np.random.normal(
222
+ 0, robot_init_qpos_noise, (len(scene.sub_scenes), len(qpos))
223
+ )
224
+ + qpos
225
+ )
226
+ qpos[:, -2:] = 0.04
227
+ robot.reset(qpos)
228
+ robot.init_qpos = robot.robot.qpos
229
+ robot.controller.controllers["gripper"].reset()
230
+
231
+ return robot
232
+
233
+
234
+ def render_images(
235
+ camera: sapien.render.RenderCameraComponent,
236
+ render_keys: list[
237
+ Literal[
238
+ "Color",
239
+ "Segmentation",
240
+ "Normal",
241
+ "Mask",
242
+ "Depth",
243
+ "Foreground",
244
+ ]
245
+ ] = None,
246
+ ) -> dict[str, Image.Image]:
247
+ """Render images from a given sapien camera.
248
+
249
+ Args:
250
+ camera (sapien.render.RenderCameraComponent): The camera to render from.
251
+ render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
252
+
253
+ Returns:
254
+ Dict[str, Image.Image]: Dictionary of rendered images.
255
+ """
256
+ if render_keys is None:
257
+ render_keys = [
258
+ "Color",
259
+ "Segmentation",
260
+ "Normal",
261
+ "Mask",
262
+ "Depth",
263
+ "Foreground",
264
+ ]
265
+
266
+ results: dict[str, Image.Image] = {}
267
+ if "Color" in render_keys:
268
+ color = camera.get_picture("Color")
269
+ color_rgb = (np.clip(color[..., :3], 0, 1) * 255).astype(np.uint8)
270
+ results["Color"] = Image.fromarray(color_rgb)
271
+
272
+ if "Mask" in render_keys:
273
+ alpha = (np.clip(color[..., 3], 0, 1) * 255).astype(np.uint8)
274
+ results["Mask"] = Image.fromarray(alpha)
275
+
276
+ if "Segmentation" in render_keys:
277
+ seg_labels = camera.get_picture("Segmentation")
278
+ label0 = seg_labels[..., 0].astype(np.uint8)
279
+ seg_color = COLOR_PALETTE[label0]
280
+ results["Segmentation"] = Image.fromarray(seg_color)
281
+
282
+ if "Foreground" in render_keys:
283
+ seg_labels = camera.get_picture("Segmentation")
284
+ label0 = seg_labels[..., 0]
285
+ mask = np.where((label0 > 1), 255, 0).astype(np.uint8)
286
+ color = camera.get_picture("Color")
287
+ color_rgb = (np.clip(color[..., :3], 0, 1) * 255).astype(np.uint8)
288
+ foreground = np.concatenate([color_rgb, mask[..., None]], axis=-1)
289
+ results["Foreground"] = Image.fromarray(foreground)
290
+
291
+ if "Normal" in render_keys:
292
+ normal = camera.get_picture("Normal")[..., :3]
293
+ normal_img = (((normal + 1) / 2) * 255).astype(np.uint8)
294
+ results["Normal"] = Image.fromarray(normal_img)
295
+
296
+ if "Depth" in render_keys:
297
+ position_map = camera.get_picture("Position")
298
+ depth = -position_map[..., 2]
299
+ alpha = torch.tensor(color[..., 3], dtype=torch.float32)
300
+ norm_depth = DiffrastRender.normalize_map_by_mask(
301
+ torch.tensor(depth), alpha
302
+ )
303
+ depth_img = (norm_depth * 255).to(torch.uint8).numpy()
304
+ results["Depth"] = Image.fromarray(depth_img)
305
+
306
+ return results
307
+
308
+
309
+ class SapienSceneManager:
310
+ """A class to manage SAPIEN simulator."""
311
+
312
+ def __init__(
313
+ self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
314
+ ) -> None:
315
+ self.sim_freq = sim_freq
316
+ self.ray_tracing = ray_tracing
317
+ self.device = device
318
+ self.renderer = sapien.SapienRenderer()
319
+ self.scene = self._setup_scene()
320
+ self.cameras: list[sapien.render.RenderCameraComponent] = []
321
+ self.actors: dict[str, sapien.pysapien.Entity] = {}
322
+
323
+ def _setup_scene(self) -> sapien.Scene:
324
+ """Set up the SAPIEN scene with lighting and ground."""
325
+ # Ray tracing settings
326
+ if self.ray_tracing:
327
+ sapien.render.set_camera_shader_dir("rt")
328
+ sapien.render.set_ray_tracing_samples_per_pixel(64)
329
+ sapien.render.set_ray_tracing_path_depth(10)
330
+ sapien.render.set_ray_tracing_denoiser("oidn")
331
+
332
+ scene = sapien.Scene()
333
+ scene.set_timestep(1 / self.sim_freq)
334
+
335
+ # Add lighting
336
+ scene.set_ambient_light([0.2, 0.2, 0.2])
337
+ scene.add_directional_light(
338
+ direction=[0, 1, -1],
339
+ color=[1.5, 1.45, 1.4],
340
+ shadow=True,
341
+ shadow_map_size=2048,
342
+ )
343
+ scene.add_directional_light(
344
+ direction=[0, -0.5, 1], color=[0.8, 0.8, 0.85], shadow=False
345
+ )
346
+ scene.add_directional_light(
347
+ direction=[0, -1, 1], color=[1.0, 1.0, 1.0], shadow=False
348
+ )
349
+
350
+ ground_material = self.renderer.create_material()
351
+ ground_material.base_color = [0.5, 0.5, 0.5, 1] # rgba, gray
352
+ ground_material.roughness = 0.7
353
+ ground_material.metallic = 0.0
354
+ scene.add_ground(0, render_material=ground_material)
355
+
356
+ return scene
357
+
358
+ def step_action(
359
+ self,
360
+ agent: BaseAgent,
361
+ action: torch.Tensor,
362
+ cameras: list[sapien.render.RenderCameraComponent],
363
+ render_keys: list[str],
364
+ sim_steps_per_control: int = 1,
365
+ ) -> dict:
366
+ agent.set_action(action)
367
+ frames = defaultdict(list)
368
+ for _ in range(sim_steps_per_control):
369
+ self.scene.step()
370
+
371
+ self.scene.update_render()
372
+ for camera in cameras:
373
+ camera.take_picture()
374
+ images = render_images(camera, render_keys=render_keys)
375
+ frames[camera.name].append(images)
376
+
377
+ return frames
378
+
379
+ def create_camera(
380
+ self,
381
+ cam_name: str,
382
+ pose: sapien.Pose,
383
+ image_hw: tuple[int, int],
384
+ fovy_deg: float,
385
+ ) -> sapien.render.RenderCameraComponent:
386
+ """Create a single camera in the scene.
387
+
388
+ Args:
389
+ cam_name (str): Name of the camera.
390
+ pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
391
+ image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
392
+ fovy_deg (float): Field of view in degrees for cameras.
393
+
394
+ Returns:
395
+ sapien.render.RenderCameraComponent: The created camera.
396
+ """
397
+ cam_actor = self.scene.create_actor_builder().build_kinematic()
398
+ cam_actor.set_pose(pose)
399
+ camera = self.scene.add_mounted_camera(
400
+ name=cam_name,
401
+ mount=cam_actor,
402
+ pose=sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0]),
403
+ width=image_hw[1],
404
+ height=image_hw[0],
405
+ fovy=np.deg2rad(fovy_deg),
406
+ near=0.01,
407
+ far=100,
408
+ )
409
+ self.cameras.append(camera)
410
+
411
+ return camera
412
+
413
+ def initialize_circular_cameras(
414
+ self,
415
+ num_cameras: int,
416
+ radius: float,
417
+ height: float,
418
+ target_pt: list[float],
419
+ image_hw: tuple[int, int],
420
+ fovy_deg: float,
421
+ ) -> list[sapien.render.RenderCameraComponent]:
422
+ """Initialize multiple cameras arranged in a circle.
423
+
424
+ Args:
425
+ num_cameras (int): Number of cameras to create.
426
+ radius (float): Radius of the camera circle.
427
+ height (float): Fixed Z-coordinate of the cameras.
428
+ target_pt (list[float]): 3D point (x, y, z) that cameras look at.
429
+ image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
430
+ fovy_deg (float): Field of view in degrees for cameras.
431
+
432
+ Returns:
433
+ List[sapien.render.RenderCameraComponent]: List of created cameras.
434
+ """
435
+ angle_step = 2 * np.pi / num_cameras
436
+ world_up_vec = np.array([0.0, 0.0, 1.0])
437
+ target_pt = np.array(target_pt)
438
+
439
+ for i in range(num_cameras):
440
+ angle = i * angle_step
441
+ cam_x = radius * np.cos(angle)
442
+ cam_y = radius * np.sin(angle)
443
+ cam_z = height
444
+ eye_pos = [cam_x, cam_y, cam_z]
445
+
446
+ forward_vec = target_pt - eye_pos
447
+ forward_vec = forward_vec / np.linalg.norm(forward_vec)
448
+ temp_right_vec = np.cross(forward_vec, world_up_vec)
449
+
450
+ if np.linalg.norm(temp_right_vec) < 1e-6:
451
+ temp_right_vec = np.array([1.0, 0.0, 0.0])
452
+ if np.abs(np.dot(temp_right_vec, forward_vec)) > 0.99:
453
+ temp_right_vec = np.array([0.0, 1.0, 0.0])
454
+
455
+ right_vec = temp_right_vec / np.linalg.norm(temp_right_vec)
456
+ up_vec = np.cross(right_vec, forward_vec)
457
+ rotation_matrix = np.array([forward_vec, -right_vec, up_vec]).T
458
+
459
+ rot = R.from_matrix(rotation_matrix)
460
+ scipy_quat = rot.as_quat() # (x, y, z, w)
461
+ quat = [
462
+ scipy_quat[3],
463
+ scipy_quat[0],
464
+ scipy_quat[1],
465
+ scipy_quat[2],
466
+ ] # (w, x, y, z)
467
+
468
+ self.create_camera(
469
+ f"camera_{i}",
470
+ sapien.Pose(p=eye_pos, q=quat),
471
+ image_hw,
472
+ fovy_deg,
473
+ )
474
+
475
+ return self.cameras
476
+
+
+
+class FrankaPandaGrasper(object):
+    def __init__(
+        self,
+        agent: BaseAgent,
+        control_freq: float,
+        joint_vel_limits: float = 2.0,
+        joint_acc_limits: float = 1.0,
+        finger_length: float = 0.025,
+    ) -> None:
+        self.agent = agent
+        self.robot = agent.robot
+        self.control_freq = control_freq
+        self.control_timestep = 1 / control_freq
+        self.joint_vel_limits = joint_vel_limits
+        self.joint_acc_limits = joint_acc_limits
+        self.finger_length = finger_length
+        self.planners = self._setup_planner()
+
+    def _setup_planner(self) -> list[mplib.Planner]:
+        # Build one motion planner per parallel environment.
+        planners = []
+        for pose in self.robot.pose:
+            link_names = [link.get_name() for link in self.robot.get_links()]
+            joint_names = [
+                joint.get_name() for joint in self.robot.get_active_joints()
+            ]
+            planner = mplib.Planner(
+                urdf=self.agent.urdf_path,
+                srdf=self.agent.urdf_path.replace(".urdf", ".srdf"),
+                user_link_names=link_names,
+                user_joint_names=joint_names,
+                move_group="panda_hand_tcp",
+                joint_vel_limits=np.ones(7) * self.joint_vel_limits,
+                joint_acc_limits=np.ones(7) * self.joint_acc_limits,
+            )
+            planner.set_base_pose(pose.raw_pose[0].tolist())
+            planners.append(planner)
+
+        return planners
+
+    def control_gripper(
+        self,
+        gripper_state: Literal[-1, 1],
+        n_step: int = 10,
+        env_idx: int = 0,
+    ) -> np.ndarray:
+        # Hold the current arm qpos and repeat the gripper command for n_step steps.
+        qpos = self.robot.get_qpos()[env_idx, :-2].cpu().numpy()
+        actions = []
+        for _ in range(n_step):
+            action = np.hstack([qpos, gripper_state])[None, ...]
+            actions.append(action)
+
+        return np.concatenate(actions, axis=0)
+
+    def move_to_pose(
+        self,
+        pose: sapien.Pose,
+        control_timestep: float,
+        gripper_state: Literal[-1, 1],
+        use_point_cloud: bool = False,
+        n_max_step: int = 100,
+        action_key: str = "position",
+        env_idx: int = 0,
+    ) -> np.ndarray | None:
+        result = self.planners[env_idx].plan_qpos_to_pose(
+            np.concatenate([pose.p, pose.q]),
+            self.robot.get_qpos().cpu().numpy()[0],
+            time_step=control_timestep,
+            use_point_cloud=use_point_cloud,
+        )
+
+        # Fall back to screw motion planning if sampling-based planning fails.
+        if result["status"] != "Success":
+            result = self.planners[env_idx].plan_screw(
+                np.concatenate([pose.p, pose.q]),
+                self.robot.get_qpos().cpu().numpy()[0],
+                time_step=control_timestep,
+                use_point_cloud=use_point_cloud,
+            )
+
+        if result["status"] != "Success":
+            return None
+
+        # Subsample the planned waypoints so at most ~n_max_step actions remain.
+        sample_ratio = (len(result[action_key]) // n_max_step) + 1
+        result[action_key] = result[action_key][::sample_ratio]
+
+        n_step = len(result[action_key])
+        actions = []
+        for i in range(n_step):
+            qpos = result[action_key][i]
+            action = np.hstack([qpos, gripper_state])[None, ...]
+            actions.append(action)
+
+        return np.concatenate(actions, axis=0)
+
+    def compute_grasp_action(
+        self,
+        actor: sapien.pysapien.Entity,
+        reach_target_only: bool = True,
+        offset: tuple[float, float, float] = (0, 0, -0.05),
+        env_idx: int = 0,
+    ) -> np.ndarray | None:
+        # Derive a grasp pose from the actor's oriented bounding box.
+        physx_rigid = actor.components[1]
+        mesh = get_component_mesh(physx_rigid, to_world_frame=True)
+        obb = mesh.bounding_box_oriented
+        approaching = np.array([0, 0, -1])
+        tcp_pose = self.agent.tcp.pose[env_idx]
+        target_closing = (
+            tcp_pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
+        )
+        grasp_info = compute_grasp_info_by_obb(
+            obb,
+            approaching=approaching,
+            target_closing=target_closing,
+            depth=self.finger_length,
+        )
+
+        closing, center = grasp_info["closing"], grasp_info["center"]
+        raw_tcp_pose = tcp_pose.sp
+        grasp_pose = self.agent.build_grasp_pose(approaching, closing, center)
+        reach_pose = grasp_pose * sapien.Pose(p=offset)
+        grasp_pose = grasp_pose * sapien.Pose(p=[0, 0, 0.01])
+        actions = []
+        reach_actions = self.move_to_pose(
+            reach_pose,
+            self.control_timestep,
+            gripper_state=1,
+            env_idx=env_idx,
+        )
+        if reach_actions is None:
+            logger.warning(
+                f"Failed to reach the grasp pose for node `{actor.name}`, skipping grasping."
+            )
+            return None
+        actions.append(reach_actions)
+
+        if not reach_target_only:
+            grasp_actions = self.move_to_pose(
+                grasp_pose,
+                self.control_timestep,
+                gripper_state=1,
+                env_idx=env_idx,
+            )
+            actions.append(grasp_actions)
+            close_actions = self.control_gripper(
+                gripper_state=-1,
+                env_idx=env_idx,
+            )
+            actions.append(close_actions)
+            back_actions = self.move_to_pose(
+                raw_tcp_pose,
+                self.control_timestep,
+                gripper_state=-1,
+                env_idx=env_idx,
+            )
+            actions.append(back_actions)
+
+        return np.concatenate(actions, axis=0)
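A hedged usage sketch of the grasper above; the env, actor, and control frequency are assumptions (a ManiSkill-style environment exposing a Panda agent), not part of this commit:

    # Hypothetical setup: `env` wraps a Panda agent, `actor` is a sapien.Entity
    # already placed in the scene.
    grasper = FrankaPandaGrasper(env.agent, control_freq=20)
    actions = grasper.compute_grasp_action(actor, reach_target_only=False)
    if actions is not None:      # None means motion planning failed
        for action in actions:   # reach -> grasp -> close -> retreat sequence
            env.step(action)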
embodied_gen/utils/tags.py CHANGED
@@ -1 +1 @@
-VERSION = "v0.1.2"
+VERSION = "v0.1.3"
embodied_gen/validators/quality_checkers.py CHANGED
@@ -109,7 +109,7 @@ class MeshGeoChecker(BaseChecker):
         if self.prompt is None:
             self.prompt = """
         You are an expert in evaluating the geometry quality of generated 3D asset.
-        You will be given rendered views of a generated 3D asset with black background.
+        You will be given rendered views of a generated 3D asset, type {}, with black background.
         Your task is to evaluate the quality of the 3D asset generation,
         including geometry, structure, and appearance, based on the rendered views.
         Criteria:
@@ -130,10 +130,13 @@ class MeshGeoChecker(BaseChecker):
         Image shows a chair with simplified back legs and soft edges → YES
         """

-    def query(self, image_paths: list[str | Image.Image]) -> str:
+    def query(
+        self, image_paths: list[str | Image.Image], text: str = "unknown"
+    ) -> str:
+        input_prompt = self.prompt.format(text)

         return self.gpt_client.query(
-            text_prompt=self.prompt,
+            text_prompt=input_prompt,
             image_base64=image_paths,
         )
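For reference, a minimal sketch of the reworked `query` call; the checker instance and image paths are hypothetical. The new `text` argument fills the "type {}" placeholder in the prompt via `str.format` before the GPT query:

    # `checker` is assumed to be an already-constructed MeshGeoChecker.
    response = checker.query(
        image_paths=["renders/view_0.png", "renders/view_1.png"],
        text="cup",  # asset category hint; defaults to "unknown"
    )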
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -24,6 +24,7 @@ from xml.dom.minidom import parseString

 import numpy as np
 import trimesh
+from embodied_gen.data.convex_decomposer import decompose_convex_mesh
 from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
 from embodied_gen.utils.process_media import render_asset3d
 from embodied_gen.utils.tags import VERSION
@@ -84,6 +85,7 @@ class URDFGenerator(object):
         attrs_name: list[str] = None,
         render_dir: str = "urdf_renders",
         render_view_num: int = 4,
+        decompose_convex: bool = False,
     ) -> None:
         if mesh_file_list is None:
             mesh_file_list = []
@@ -107,36 +109,37 @@ class URDFGenerator(object):
        already provided, use it directly), accurately describe this 3D object asset (within 15 words),
        Determine the pose of the object in the first image and estimate the true vertical height
        (vertical projection) range of the object (in meters), i.e., how tall the object appears from top
-       to bottom in the front view (first) image. also weight range (unit: kilogram), the average
+       to bottom in the first image. Also estimate the weight range (unit: kilogram), the average
        static friction coefficient of the object relative to rubber and the average dynamic friction
-       coefficient of the object relative to rubber. Return response format as shown in Output Example.
+       coefficient of the object relative to rubber. Return the response in the format shown in the Output Example.

        Output Example:
        Category: cup
        Description: shiny golden cup with floral design
-       Height: 0.1-0.15 m
+       Pose: <short_description_within_10_words>
+       Height: 0.10-0.15 m
        Weight: 0.3-0.6 kg
        Static friction coefficient: 0.6
        Dynamic friction coefficient: 0.5

-       IMPORTANT: Estimating Vertical Height from the First (Front View) Image.
+       IMPORTANT: Estimating Vertical Height from the First (Front View) Image, with pose estimation based on all views.
        - The "vertical height" refers to the real-world vertical size of the object
        as projected in the first image, aligned with the image's vertical axis.
        - For flat objects like plates or disks or book, if their face is visible in the front view,
        use the diameter as the vertical height. If the edge is visible, use the thickness instead.
        - This is not necessarily the full length of the object, but how tall it appears
-       in the first image vertically, based on its pose and orientation.
-       - For objects(e.g., spoons, forks, writing instruments etc.) at an angle showing in
-       the first image, tilted at 45Β° will appear shorter vertically than when upright.
+       in the first image vertically, based on its pose and orientation estimated from all views.
+       - For objects (e.g., spoons, forks, writing instruments, etc.) shown at an angle in the images,
+       e.g., tilted at 45°, they will appear shorter vertically than when upright.
        Estimate the vertical projection of their real length based on its pose.
        For example:
-       - A pen standing upright in the first view (aligned with the image's vertical axis)
+       - A pen standing upright in the first image (aligned with the image's vertical axis)
        full body visible in the first image: → vertical height ≈ 0.14-0.20 m
-       - A pen lying flat in the front view (showing thickness) β†’ vertical height β‰ˆ 0.018-0.025 m
+       - A pen lying flat in the first image (showing thickness or as a dot) → vertical height ≈ 0.018-0.025 m
        - Tilted pen in the first image (e.g., ~45° angle): vertical height ≈ 0.07-0.12 m
-       - Use the rest views(except the first image) to help determine the object's 3D pose and orientation.
+       - Use the rest of the views to help determine the object's 3D pose and orientation.
        Assume the object is in real-world scale and estimate the approximate vertical height
-       (in meters) based on how large it appears vertically in the first image.
+       based on the pose estimation and how large it appears vertically in the first image.
        """
        )
@@ -155,6 +158,7 @@ class URDFGenerator(object):
             "gs_model",
         ]
         self.attrs_name = attrs_name
+        self.decompose_convex = decompose_convex

     def parse_response(self, response: str) -> dict[str, any]:
         lines = response.split("\n")
@@ -163,14 +167,14 @@ class URDFGenerator(object):
         description = lines[1].split(": ")[1]
         min_height, max_height = map(
             lambda x: float(x.strip().replace(",", "").split()[0]),
-            lines[2].split(": ")[1].split("-"),
+            lines[3].split(": ")[1].split("-"),
         )
         min_mass, max_mass = map(
             lambda x: float(x.strip().replace(",", "").split()[0]),
-            lines[3].split(": ")[1].split("-"),
+            lines[4].split(": ")[1].split("-"),
         )
-        mu1 = float(lines[4].split(": ")[1].replace(",", ""))
-        mu2 = float(lines[5].split(": ")[1].replace(",", ""))
+        mu1 = float(lines[5].split(": ")[1].replace(",", ""))
+        mu2 = float(lines[6].split(": ")[1].replace(",", ""))

         return {
             "category": category.lower(),
@@ -257,9 +261,24 @@ class URDFGenerator(object):
         # Update collision geometry
         collision = link.find("collision/geometry/mesh")
         if collision is not None:
-            collision.set(
-                "filename", os.path.join(self.output_mesh_dir, obj_name)
-            )
+            collision_mesh = os.path.join(self.output_mesh_dir, obj_name)
+            if self.decompose_convex:
+                try:
+                    d_params = dict(
+                        threshold=0.05, max_convex_hull=64, verbose=False
+                    )
+                    filename = f"{os.path.splitext(obj_name)[0]}_collision.ply"
+                    output_path = os.path.join(mesh_folder, filename)
+                    decompose_convex_mesh(
+                        mesh_output_path, output_path, **d_params
+                    )
+                    collision_mesh = f"{self.output_mesh_dir}/{filename}"
+                except Exception as e:
+                    logger.warning(
+                        f"Convex decomposition failed for {output_path}: {e}. "
+                        "Using the original mesh for collision computation."
+                    )
+            collision.set("filename", collision_mesh)
             collision.set("scale", "1.0 1.0 1.0")

         # Update friction coefficients
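Because the prompt now asks for a `Pose:` line between `Description:` and `Height:`, `parse_response` reads height, weight, and the friction coefficients one line later than before. An illustrative response that the updated parser accepts (values made up):

    response = "\n".join([
        "Category: cup",                                     # lines[0]
        "Description: shiny golden cup with floral design",  # lines[1]
        "Pose: upright, handle facing the camera",           # lines[2], new
        "Height: 0.10-0.15 m",                               # lines[3]
        "Weight: 0.3-0.6 kg",                                # lines[4]
        "Static friction coefficient: 0.6",                  # lines[5]
        "Dynamic friction coefficient: 0.5",                 # lines[6]
    ])
    attrs = urdf_generator.parse_response(response)  # hypothetical instance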
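And a standalone sketch of the convex-decomposition path added above, reusing the same `decompose_convex_mesh` parameters; the input and output paths are hypothetical:

    from embodied_gen.data.convex_decomposer import decompose_convex_mesh

    decompose_convex_mesh(
        "outputs/mesh/sample.obj",            # hypothetical visual mesh
        "outputs/mesh/sample_collision.ply",  # merged convex hulls for collision
        threshold=0.05,
        max_convex_hull=64,
        verbose=False,
    )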