xinjie.wang committed on
Commit · be0ecc3
Parent(s): 6ea980c

update
Browse files

- common.py +4 -2
- embodied_gen/data/convex_decomposer.py +161 -0
- embodied_gen/data/mesh_operator.py +7 -1
- embodied_gen/envs/pick_embodiedgen.py +389 -0
- embodied_gen/models/gs_model.py +6 -8
- embodied_gen/models/layout.py +509 -0
- embodied_gen/models/sr_model.py +1 -1
- embodied_gen/models/text_model.py +3 -3
- embodied_gen/models/texture_model.py +1 -1
- embodied_gen/scripts/compose_layout.py +73 -0
- embodied_gen/scripts/gen_layout.py +156 -0
- embodied_gen/scripts/imageto3d.py +13 -3
- embodied_gen/scripts/parallel_sim.py +148 -0
- embodied_gen/scripts/simulate_sapien.py +195 -0
- embodied_gen/scripts/textto3d.py +3 -1
- embodied_gen/scripts/textto3d.sh +1 -0
- embodied_gen/trainer/gsplat_trainer.py +1 -1
- embodied_gen/trainer/pono2mesh_trainer.py +1 -1
- embodied_gen/utils/config.py +12 -0
- embodied_gen/utils/enum.py +1 -0
- embodied_gen/utils/gaussian.py +5 -6
- embodied_gen/utils/geometry.py +458 -0
- embodied_gen/utils/monkey_patches.py +66 -0
- embodied_gen/utils/process_media.py +49 -6
- embodied_gen/utils/simulation.py +633 -0
- embodied_gen/utils/tags.py +1 -1
- embodied_gen/validators/quality_checkers.py +6 -3
- embodied_gen/validators/urdf_convertor.py +37 -18
- requirements.txt +1 -0
 
    	
common.py    CHANGED

@@ -189,7 +189,7 @@ os.makedirs(TMP_DIR, exist_ok=True)
 lighting_css = """
 <style>
 #lighter_mesh canvas {
-    filter: brightness(1.
+    filter: brightness(1.9) !important;
 }
 </style>
 """

@@ -547,7 +547,9 @@ def extract_urdf(
 
     # Convert to URDF and recover attrs by GPT.
     filename = "sample"
-    urdf_convertor = URDFGenerator(
+    urdf_convertor = URDFGenerator(
+        GPT_CLIENT, render_view_num=4, decompose_convex=True
+    )
    asset_attrs = {
        "version": VERSION,
        "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
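The notable change above is that URDFGenerator is now built with decompose_convex=True, which presumably hands collision-mesh generation to the new embodied_gen/data/convex_decomposer.py module added later in this commit. A rough standalone sketch of that step follows; the input/output paths and the wiring into URDFGenerator are assumptions for illustration, not taken from the diff.

# Hypothetical standalone equivalent of the collision-mesh step that
# decompose_convex=True presumably enables; paths are illustrative only.
from embodied_gen.data.convex_decomposer import decompose_convex_mesh

collision_path = decompose_convex_mesh(
    filename="outputs/sample/mesh/sample.obj",        # visual mesh (hypothetical path)
    outfile="outputs/sample/mesh/sample_collision.obj",
    threshold=0.05,   # CoACD concavity threshold, the module default
    verbose=True,
)
print(collision_path)  # path to the merged convex parts exported by CoACD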
    	
embodied_gen/data/convex_decomposer.py    ADDED

@@ -0,0 +1,161 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import logging
+import multiprocessing as mp
+import os
+
+import coacd
+import trimesh
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "decompose_convex_coacd",
+    "decompose_convex_mesh",
+    "decompose_convex_process",
+]
+
+
+def decompose_convex_coacd(
+    filename: str, outfile: str, params: dict, verbose: bool = False
+) -> None:
+    coacd.set_log_level("info" if verbose else "warn")
+
+    mesh = trimesh.load(filename, force="mesh")
+    mesh = coacd.Mesh(mesh.vertices, mesh.faces)
+
+    result = coacd.run_coacd(mesh, **params)
+    combined = sum([trimesh.Trimesh(*m) for m in result])
+    combined.export(outfile)
+
+
+def decompose_convex_mesh(
+    filename: str,
+    outfile: str,
+    threshold: float = 0.05,
+    max_convex_hull: int = -1,
+    preprocess_mode: str = "auto",
+    preprocess_resolution: int = 30,
+    resolution: int = 2000,
+    mcts_nodes: int = 20,
+    mcts_iterations: int = 150,
+    mcts_max_depth: int = 3,
+    pca: bool = False,
+    merge: bool = True,
+    seed: int = 0,
+    verbose: bool = False,
+) -> str:
+    """Decompose a mesh into convex parts using the CoACD algorithm."""
+    coacd.set_log_level("info" if verbose else "warn")
+
+    if os.path.exists(outfile):
+        logger.warning(f"Output file {outfile} already exists, removing it.")
+        os.remove(outfile)
+
+    params = dict(
+        threshold=threshold,
+        max_convex_hull=max_convex_hull,
+        preprocess_mode=preprocess_mode,
+        preprocess_resolution=preprocess_resolution,
+        resolution=resolution,
+        mcts_nodes=mcts_nodes,
+        mcts_iterations=mcts_iterations,
+        mcts_max_depth=mcts_max_depth,
+        pca=pca,
+        merge=merge,
+        seed=seed,
+    )
+
+    try:
+        decompose_convex_coacd(filename, outfile, params, verbose)
+        if os.path.exists(outfile):
+            return outfile
+    except Exception as e:
+        if verbose:
+            print(f"Decompose convex first attempt failed: {e}.")
+
+    if preprocess_mode != "on":
+        try:
+            params["preprocess_mode"] = "on"
+            decompose_convex_coacd(filename, outfile, params, verbose)
+            if os.path.exists(outfile):
+                return outfile
+        except Exception as e:
+            if verbose:
+                print(
+                    f"Decompose convex second attempt with preprocess_mode='on' failed: {e}"
+                )
+
+    raise RuntimeError(f"Convex decomposition failed on {filename}")
+
+
+def decompose_convex_mp(
+    filename: str,
+    outfile: str,
+    threshold: float = 0.05,
+    max_convex_hull: int = -1,
+    preprocess_mode: str = "auto",
+    preprocess_resolution: int = 30,
+    resolution: int = 2000,
+    mcts_nodes: int = 20,
+    mcts_iterations: int = 150,
+    mcts_max_depth: int = 3,
+    pca: bool = False,
+    merge: bool = True,
+    seed: int = 0,
+    verbose: bool = False,
+) -> str:
+    """Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
+
+    See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
+    """
+    params = dict(
+        threshold=threshold,
+        max_convex_hull=max_convex_hull,
+        preprocess_mode=preprocess_mode,
+        preprocess_resolution=preprocess_resolution,
+        resolution=resolution,
+        mcts_nodes=mcts_nodes,
+        mcts_iterations=mcts_iterations,
+        mcts_max_depth=mcts_max_depth,
+        pca=pca,
+        merge=merge,
+        seed=seed,
+    )
+
+    ctx = mp.get_context("spawn")
+    p = ctx.Process(
+        target=decompose_convex_coacd,
+        args=(filename, outfile, params, verbose),
+    )
+    p.start()
+    p.join()
+    if p.exitcode == 0 and os.path.exists(outfile):
+        return outfile
+
+    if preprocess_mode != "on":
+        params["preprocess_mode"] = "on"
+        p = ctx.Process(
+            target=decompose_convex_coacd,
+            args=(filename, outfile, params, verbose),
+        )
+        p.start()
+        p.join()
+        if p.exitcode == 0 and os.path.exists(outfile):
+            return outfile
+
+    raise RuntimeError(f"Convex decomposition failed on {filename}")
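Besides the in-process helpers, the new module defines decompose_convex_mp, which runs CoACD in a spawned child process so a native crash or hang inside the decomposition library cannot take down the caller, and retries with preprocess_mode="on" just like decompose_convex_mesh. A minimal usage sketch, assuming hypothetical file paths; the __main__ guard is advisable because of the spawn start method when this is run as a script:

# Minimal sketch of the process-isolated variant; paths are hypothetical.
from embodied_gen.data.convex_decomposer import decompose_convex_mp

if __name__ == "__main__":
    out = decompose_convex_mp(
        filename="assets/mug/mesh.obj",          # input mesh (hypothetical)
        outfile="assets/mug/mesh_collision.obj",
        max_convex_hull=16,   # cap the number of convex pieces
        verbose=True,
    )
    print(f"Collision mesh written to {out}")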
    	
embodied_gen/data/mesh_operator.py    CHANGED

@@ -16,13 +16,17 @@
 
 
 import logging
+import multiprocessing as mp
+import os
 from typing import Tuple, Union
 
+import coacd
 import igraph
 import numpy as np
 import pyvista as pv
 import spaces
 import torch
+import trimesh
 import utils3d
 from pymeshfix import _meshfix
 from tqdm import tqdm

@@ -33,7 +37,9 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-__all__ = [
+__all__ = [
+    "MeshFixer",
+]
 
 
 def _radical_inverse(base, n):
    	
embodied_gen/envs/pick_embodiedgen.py    ADDED

@@ -0,0 +1,389 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import json
+import os
+from copy import deepcopy
+
+import numpy as np
+import sapien
+import torch
+import torchvision.transforms as transforms
+from mani_skill.envs.sapien_env import BaseEnv
+from mani_skill.sensors.camera import CameraConfig
+from mani_skill.utils import sapien_utils
+from mani_skill.utils.building import actors
+from mani_skill.utils.registration import register_env
+from mani_skill.utils.structs.actor import Actor
+from mani_skill.utils.structs.pose import Pose
+from mani_skill.utils.structs.types import (
+    GPUMemoryConfig,
+    SceneConfig,
+    SimConfig,
+)
+from mani_skill.utils.visualization.misc import tile_images
+from tqdm import tqdm
+from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
+from embodied_gen.utils.geometry import bfs_placement, quaternion_multiply
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import alpha_blend_rgba
+from embodied_gen.utils.simulation import (
+    SIM_COORD_ALIGN,
+    load_assets_from_layout_file,
+)
+
+__all__ = ["PickEmbodiedGen"]
+
+
+@register_env("PickEmbodiedGen-v1", max_episode_steps=100)
+class PickEmbodiedGen(BaseEnv):
+    SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
+    goal_thresh = 0.0
+
+    def __init__(
+        self,
+        *args,
+        robot_uids: str | list[str] = "panda",
+        robot_init_qpos_noise: float = 0.02,
+        num_envs: int = 1,
+        reconfiguration_freq: int = None,
+        **kwargs,
+    ):
+        self.robot_init_qpos_noise = robot_init_qpos_noise
+        if reconfiguration_freq is None:
+            if num_envs == 1:
+                reconfiguration_freq = 1
+            else:
+                reconfiguration_freq = 0
+
+        # Init params from kwargs.
+        layout_file = kwargs.pop("layout_file", None)
+        replace_objs = kwargs.pop("replace_objs", True)
+        self.enable_grasp = kwargs.pop("enable_grasp", False)
+        self.init_quat = kwargs.pop("init_quat", [0.7071, 0, 0, 0.7071])
+        # Add small offset in z-axis to avoid collision.
+        self.objs_z_offset = kwargs.pop("objs_z_offset", 0.002)
+        self.robot_z_offset = kwargs.pop("robot_z_offset", 0.002)
+
+        self.layouts = self.init_env_layouts(
+            layout_file, num_envs, replace_objs
+        )
+        self.robot_pose = self.compute_robot_init_pose(
+            self.layouts, num_envs, self.robot_z_offset
+        )
+        self.env_actors = dict()
+        self.image_transform = transforms.PILToTensor()
+
+        super().__init__(
+            *args,
+            robot_uids=robot_uids,
+            reconfiguration_freq=reconfiguration_freq,
+            num_envs=num_envs,
+            **kwargs,
+        )
+
+        self.bg_images = dict()
+        if self.render_mode == "hybrid":
+            self.bg_images = self.render_gs3d_images(
+                self.layouts, num_envs, self.init_quat
+            )
+
+    @staticmethod
+    def init_env_layouts(
+        layout_file: str, num_envs: int, replace_objs: bool
+    ) -> list[LayoutInfo]:
+        layout = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
+        layouts = []
+        for env_idx in range(num_envs):
+            if replace_objs and env_idx > 0:
+                layout = bfs_placement(deepcopy(layout))
+            layouts.append(layout)
+
+        return layouts
+
+    @staticmethod
+    def compute_robot_init_pose(
+        layouts: list[LayoutInfo], num_envs: int, z_offset: float = 0.0
+    ) -> list[list[float]]:
+        robot_pose = []
+        for env_idx in range(num_envs):
+            layout = layouts[env_idx]
+            robot_node = layout.relation[Scene3DItemEnum.ROBOT.value]
+            x, y, z, qx, qy, qz, qw = layout.position[robot_node]
+            robot_pose.append([x, y, z + z_offset, qw, qx, qy, qz])
+
+        return robot_pose
+
+    @property
+    def _default_sim_config(self):
+        return SimConfig(
+            scene_config=SceneConfig(
+                solver_position_iterations=30,
+                # contact_offset=0.04,
+                # rest_offset=0.001,
+            ),
+            # sim_freq=200,
+            control_freq=50,
+            gpu_memory_config=GPUMemoryConfig(
+                max_rigid_contact_count=2**20, max_rigid_patch_count=2**19
+            ),
+        )
+
+    @property
+    def _default_sensor_configs(self):
+        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
+
+        return [
+            CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)
+        ]
+
+    @property
+    def _default_human_render_camera_configs(self):
+        pose = sapien_utils.look_at(
+            eye=[0.9, 0.0, 1.1], target=[0.0, 0.0, 0.9]
+        )
+
+        return CameraConfig(
+            "render_camera", pose, 256, 256, np.deg2rad(75), 0.01, 100
+        )
+
+    def _load_agent(self, options: dict):
+        super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
+
+    def _load_scene(self, options: dict):
+        all_objects = []
+        logger.info(f"Loading assets and decomposition mesh collisions...")
+        for env_idx in range(self.num_envs):
+            env_actors = load_assets_from_layout_file(
+                self.scene,
+                self.layouts[env_idx],
+                z_offset=self.objs_z_offset,
+                init_quat=self.init_quat,
+                env_idx=env_idx,
+            )
+            self.env_actors[f"env{env_idx}"] = env_actors
+            all_objects.extend(env_actors.values())
+
+        self.obj = all_objects[-1]
+        for obj in all_objects:
+            self.remove_from_state_dict_registry(obj)
+
+        self.all_objects = Actor.merge(all_objects, name="all_objects")
+        self.add_to_state_dict_registry(self.all_objects)
+
+        self.goal_site = actors.build_sphere(
+            self.scene,
+            radius=self.goal_thresh,
+            color=[0, 1, 0, 0],
+            name="goal_site",
+            body_type="kinematic",
+            add_collision=False,
+            initial_pose=sapien.Pose(),
+        )
+        self._hidden_objects.append(self.goal_site)
+
+    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
+        with torch.device(self.device):
+            b = len(env_idx)
+            goal_xyz = torch.zeros((b, 3))
+            goal_xyz[:, :2] = torch.rand((b, 2)) * 0.2 - 0.1
+            self.goal_site.set_pose(Pose.create_from_pq(goal_xyz))
+
+            qpos = np.array(
+                [
+                    0.0,
+                    np.pi / 8,
+                    0,
+                    -np.pi * 3 / 8,
+                    0,
+                    np.pi * 3 / 4,
+                    np.pi / 4,
+                    0.04,
+                    0.04,
+                ]
+            )
+            qpos = (
+                np.random.normal(
+                    0, self.robot_init_qpos_noise, (self.num_envs, len(qpos))
+                )
+                + qpos
+            )
+            qpos[:, -2:] = 0.04
+            self.agent.robot.set_root_pose(np.array(self.robot_pose))
+            self.agent.reset(qpos)
+            self.agent.init_qpos = qpos
+            self.agent.controller.controllers["gripper"].reset()
+
+    def render_gs3d_images(
+        self, layouts: list[LayoutInfo], num_envs: int, init_quat: list[float]
+    ) -> dict[str, np.ndarray]:
+        sim_coord_align = (
+            torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
+        )
+        cameras = self.scene.sensors.copy()
+        cameras.update(self.scene.human_render_cameras)
+
+        bg_node = layouts[0].relation[Scene3DItemEnum.BACKGROUND.value]
+        gs_path = os.path.join(layouts[0].assets[bg_node], "gs_model.ply")
+        raw_gs: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
+        bg_images = dict()
+        for env_idx in tqdm(range(num_envs), desc="Pre-rendering Background"):
+            layout = layouts[env_idx]
+            x, y, z, qx, qy, qz, qw = layout.position[bg_node]
+            qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], init_quat)
+            init_pose = torch.tensor([x, y, z, qx, qy, qz, qw])
+            gs_model = raw_gs.get_gaussians(instance_pose=init_pose)
+            for key in cameras:
+                camera = cameras[key]
+                Ks = camera.camera.get_intrinsic_matrix()  # (n_env, 3, 3)
+                c2w = camera.camera.get_model_matrix()  # (n_env, 4, 4)
+                result = gs_model.render(
+                    c2w[env_idx] @ sim_coord_align,
+                    Ks[env_idx],
+                    image_width=camera.config.width,
+                    image_height=camera.config.height,
+                )
+                bg_images[f"{key}-env{env_idx}"] = result.rgb[..., ::-1]
+
+        return bg_images
+
+    def render(self):
+        if self.render_mode is None:
+            raise RuntimeError("render_mode is not set.")
+        if self.render_mode == "human":
+            return self.render_human()
+        elif self.render_mode == "rgb_array":
+            res = self.render_rgb_array()
+            return res
+        elif self.render_mode == "sensors":
+            res = self.render_sensors()
+            return res
+        elif self.render_mode == "all":
+            return self.render_all()
+        elif self.render_mode == "hybrid":
+            return self.hybrid_render()
+        else:
+            raise NotImplementedError(
+                f"Unsupported render mode {self.render_mode}."
+            )
+
+    def render_rgb_array(
+        self, camera_name: str = None, return_alpha: bool = False
+    ):
+        for obj in self._hidden_objects:
+            obj.show_visual()
+        self.scene.update_render(
+            update_sensors=False, update_human_render_cameras=True
+        )
+        images = []
+        render_images = self.scene.get_human_render_camera_images(
+            camera_name, return_alpha
+        )
+        for image in render_images.values():
+            images.append(image)
         
            +
                    if len(images) == 0:
         
     | 
| 299 | 
         
            +
                        return None
         
     | 
| 300 | 
         
            +
                    if len(images) == 1:
         
     | 
| 301 | 
         
            +
                        return images[0]
         
     | 
| 302 | 
         
            +
                    for obj in self._hidden_objects:
         
     | 
| 303 | 
         
            +
                        obj.hide_visual()
         
     | 
| 304 | 
         
            +
                    return tile_images(images)
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                def render_sensors(self):
         
     | 
| 307 | 
         
            +
                    images = []
         
     | 
| 308 | 
         
            +
                    sensor_images = self.get_sensor_images()
         
     | 
| 309 | 
         
            +
                    for image in sensor_images.values():
         
     | 
| 310 | 
         
            +
                        for img in image.values():
         
     | 
| 311 | 
         
            +
                            images.append(img)
         
     | 
| 312 | 
         
            +
                    return tile_images(images)
         
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
                def hybrid_render(self):
         
     | 
| 315 | 
         
            +
                    fg_images = self.render_rgb_array(
         
     | 
| 316 | 
         
            +
                        return_alpha=True
         
     | 
| 317 | 
         
            +
                    )  # (n_env, h, w, 3)
         
     | 
| 318 | 
         
            +
                    images = []
         
     | 
| 319 | 
         
            +
                    for key in self.bg_images:
         
     | 
| 320 | 
         
            +
                        if "render_camera" not in key:
         
     | 
| 321 | 
         
            +
                            continue
         
     | 
| 322 | 
         
            +
                        env_idx = int(key.split("-env")[-1])
         
     | 
| 323 | 
         
            +
                        rgba = alpha_blend_rgba(
         
     | 
| 324 | 
         
            +
                            fg_images[env_idx].cpu().numpy(), self.bg_images[key]
         
     | 
| 325 | 
         
            +
                        )
         
     | 
| 326 | 
         
            +
                        images.append(self.image_transform(rgba))
         
     | 
| 327 | 
         
            +
             
     | 
| 328 | 
         
            +
                    images = torch.stack(images, dim=0)
         
     | 
| 329 | 
         
            +
                    images = images.permute(0, 2, 3, 1)
         
     | 
| 330 | 
         
            +
             
     | 
| 331 | 
         
            +
                    return images[..., :3]
         
     | 
| 332 | 
         
            +
             
     | 
| 333 | 
         
            +
                def evaluate(self):
         
     | 
| 334 | 
         
            +
                    obj_to_goal_pos = (
         
     | 
| 335 | 
         
            +
                        self.obj.pose.p
         
     | 
| 336 | 
         
            +
                    )  # self.goal_site.pose.p - self.obj.pose.p
         
     | 
| 337 | 
         
            +
                    is_obj_placed = (
         
     | 
| 338 | 
         
            +
                        torch.linalg.norm(obj_to_goal_pos, axis=1) <= self.goal_thresh
         
     | 
| 339 | 
         
            +
                    )
         
     | 
| 340 | 
         
            +
                    is_grasped = self.agent.is_grasping(self.obj)
         
     | 
| 341 | 
         
            +
                    is_robot_static = self.agent.is_static(0.2)
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
                    return dict(
         
     | 
| 344 | 
         
            +
                        is_grasped=is_grasped,
         
     | 
| 345 | 
         
            +
                        obj_to_goal_pos=obj_to_goal_pos,
         
     | 
| 346 | 
         
            +
                        is_obj_placed=is_obj_placed,
         
     | 
| 347 | 
         
            +
                        is_robot_static=is_robot_static,
         
     | 
| 348 | 
         
            +
                        is_grasping=self.agent.is_grasping(self.obj),
         
     | 
| 349 | 
         
            +
                        success=torch.logical_and(is_obj_placed, is_robot_static),
         
     | 
| 350 | 
         
            +
                    )
         
     | 
| 351 | 
         
            +
             
     | 
| 352 | 
         
            +
                def _get_obs_extra(self, info: dict):
         
     | 
| 353 | 
         
            +
             
     | 
| 354 | 
         
            +
                    return dict()
         
     | 
| 355 | 
         
            +
             
     | 
| 356 | 
         
            +
                def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
         
     | 
| 357 | 
         
            +
                    tcp_to_obj_dist = torch.linalg.norm(
         
     | 
| 358 | 
         
            +
                        self.obj.pose.p - self.agent.tcp.pose.p, axis=1
         
     | 
| 359 | 
         
            +
                    )
         
     | 
| 360 | 
         
            +
                    reaching_reward = 1 - torch.tanh(5 * tcp_to_obj_dist)
         
     | 
| 361 | 
         
            +
                    reward = reaching_reward
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
                    is_grasped = info["is_grasped"]
         
     | 
| 364 | 
         
            +
                    reward += is_grasped
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                    # obj_to_goal_dist = torch.linalg.norm(
         
     | 
| 367 | 
         
            +
                    #     self.goal_site.pose.p - self.obj.pose.p, axis=1
         
     | 
| 368 | 
         
            +
                    # )
         
     | 
| 369 | 
         
            +
                    obj_to_goal_dist = torch.linalg.norm(
         
     | 
| 370 | 
         
            +
                        self.obj.pose.p - self.obj.pose.p, axis=1
         
     | 
| 371 | 
         
            +
                    )
         
     | 
| 372 | 
         
            +
                    place_reward = 1 - torch.tanh(5 * obj_to_goal_dist)
         
     | 
| 373 | 
         
            +
                    reward += place_reward * is_grasped
         
     | 
| 374 | 
         
            +
             
     | 
| 375 | 
         
            +
                    reward += info["is_obj_placed"] * is_grasped
         
     | 
| 376 | 
         
            +
             
     | 
| 377 | 
         
            +
                    static_reward = 1 - torch.tanh(
         
     | 
| 378 | 
         
            +
                        5
         
     | 
| 379 | 
         
            +
                        * torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], axis=1)
         
     | 
| 380 | 
         
            +
                    )
         
     | 
| 381 | 
         
            +
                    reward += static_reward * info["is_obj_placed"] * is_grasped
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                    reward[info["success"]] = 6
         
     | 
| 384 | 
         
            +
                    return reward
         
     | 
| 385 | 
         
            +
             
     | 
| 386 | 
         
            +
                def compute_normalized_dense_reward(
         
     | 
| 387 | 
         
            +
                    self, obs: any, action: torch.Tensor, info: dict
         
     | 
| 388 | 
         
            +
                ):
         
     | 
| 389 | 
         
            +
                    return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
         
     | 
    	
embodied_gen/models/gs_model.py CHANGED

@@ -51,17 +51,15 @@ class RenderResult:
 
     def __post_init__(self):
         if isinstance(self.rgb, torch.Tensor):
-            rgb = self.rgb
-            rgb = (
-            self.rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
+            rgb = (self.rgb * 255).to(torch.uint8)
+            self.rgb = rgb.cpu().numpy()[..., ::-1]
         if isinstance(self.depth, torch.Tensor):
-            self.depth = self.depth.
+            self.depth = self.depth.cpu().numpy()
         if isinstance(self.opacity, torch.Tensor):
-            opacity = self.opacity
-            opacity = (
-            self.opacity = cv2.cvtColor(opacity, cv2.COLOR_GRAY2RGB)
+            opacity = (self.opacity * 255).to(torch.uint8)
+            self.opacity = opacity.cpu().numpy()
             mask = np.where(self.opacity > self.mask_threshold, 255, 0)
-            self.mask = mask
+            self.mask = mask.astype(np.uint8)
             self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
 
 
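Editor's note: the updated __post_init__ scales float render outputs to [0, 255], casts to uint8 numpy, flips the channel order, and builds a binary mask plus an RGBA image. The sketch below reproduces that conversion in isolation, assuming only torch and numpy, float inputs in [0, 1], and illustrative variable names; it is not the class's code.

import numpy as np
import torch

# Float tensors in [0, 1], as a renderer might return them.
rgb_t = torch.rand(4, 4, 3)
opacity_t = torch.rand(4, 4, 1)
mask_threshold = 0.5  # assumed to be expressed on the [0, 1] scale here

# Scale to [0, 255], cast to uint8, move to numpy, and flip channel order.
rgb = (rgb_t * 255).to(torch.uint8).cpu().numpy()[..., ::-1]
opacity = (opacity_t * 255).to(torch.uint8).cpu().numpy()

# Binary mask from the opacity channel, then an RGBA image for downstream use.
mask = np.where(opacity > mask_threshold * 255, 255, 0).astype(np.uint8)
rgba = np.concatenate([rgb, mask], axis=-1)
print(rgba.shape, rgba.dtype)  # (4, 4, 4) uint8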
    	
embodied_gen/models/layout.py ADDED

@@ -0,0 +1,509 @@
| 1 | 
         
            +
            # Project EmbodiedGen
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
         
     | 
| 4 | 
         
            +
            #
         
     | 
| 5 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 6 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 7 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            #       http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 10 | 
         
            +
            #
         
     | 
| 11 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 12 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 13 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
         
     | 
| 14 | 
         
            +
            # implied. See the License for the specific language governing
         
     | 
| 15 | 
         
            +
            # permissions and limitations under the License.
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            import argparse
         
     | 
| 19 | 
         
            +
            import json
         
     | 
| 20 | 
         
            +
            import logging
         
     | 
| 21 | 
         
            +
            import os
         
     | 
| 22 | 
         
            +
            import re
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            import json_repair
         
     | 
| 25 | 
         
            +
            from embodied_gen.utils.enum import (
         
     | 
| 26 | 
         
            +
                LayoutInfo,
         
     | 
| 27 | 
         
            +
                RobotItemEnum,
         
     | 
| 28 | 
         
            +
                Scene3DItemEnum,
         
     | 
| 29 | 
         
            +
                SpatialRelationEnum,
         
     | 
| 30 | 
         
            +
            )
         
     | 
| 31 | 
         
            +
            from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
         
     | 
| 32 | 
         
            +
            from embodied_gen.utils.process_media import SceneTreeVisualizer
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            logging.basicConfig(level=logging.INFO)
         
     | 
| 35 | 
         
            +
            logger = logging.getLogger(__name__)
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            __all__ = [
         
     | 
| 39 | 
         
            +
                "LayoutDesigner",
         
     | 
| 40 | 
         
            +
                "LAYOUT_DISASSEMBLER",
         
     | 
| 41 | 
         
            +
                "LAYOUT_GRAPHER",
         
     | 
| 42 | 
         
            +
                "LAYOUT_DESCRIBER",
         
     | 
| 43 | 
         
            +
            ]
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            DISTRACTOR_NUM = 3  # Maximum number of distractor objects allowed
         
     | 
| 47 | 
         
            +
            LAYOUT_DISASSEMBLE_PROMPT = f"""
         
     | 
| 48 | 
         
            +
                You are an intelligent 3D scene planner. Given a natural language
         
     | 
| 49 | 
         
            +
                description of a robotic task, output a structured description of
         
     | 
| 50 | 
         
            +
                an interactive 3D scene.
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                The output must include the following fields:
         
     | 
| 53 | 
         
            +
                - task: A high-level task type (e.g., "single-arm pick",
         
     | 
| 54 | 
         
            +
                    "dual-arm grasping", "pick and place", "object sorting").
         
     | 
| 55 | 
         
            +
                - {Scene3DItemEnum.ROBOT}: The name or type of robot involved. If not mentioned,
         
     | 
| 56 | 
         
            +
                    use {RobotItemEnum.FRANKA} as default.
         
     | 
| 57 | 
         
            +
                - {Scene3DItemEnum.BACKGROUND}: The room or indoor environment where the task happens
         
     | 
| 58 | 
         
            +
                    (e.g., Kitchen, Bedroom, Living Room, Workshop, Office).
         
     | 
| 59 | 
         
            +
                - {Scene3DItemEnum.CONTEXT}: A indoor object involved in the manipulation
         
     | 
| 60 | 
         
            +
                    (e.g., Table, Shelf, Desk, Bed, Cabinet).
         
     | 
| 61 | 
         
            +
                - {Scene3DItemEnum.MANIPULATED_OBJS}: The main object(s) that the robot directly interacts with.
         
     | 
| 62 | 
         
            +
                - {Scene3DItemEnum.DISTRACTOR_OBJS}: Other objects that naturally belong to the scene but are not part of the main task.
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                Constraints:
         
     | 
| 65 | 
         
            +
                - The {Scene3DItemEnum.BACKGROUND} must logically match the described task.
         
     | 
| 66 | 
         
            +
                - The {Scene3DItemEnum.CONTEXT} must fit within the {Scene3DItemEnum.BACKGROUND}. (e.g., a bedroom may include a table or bed, but not a workbench.)
         
     | 
| 67 | 
         
            +
                - The {Scene3DItemEnum.CONTEXT} must be a concrete indoor object, such as a "table",
         
     | 
| 68 | 
         
            +
                    "shelf", "desk", or "bed". It must not be an abstract concept (e.g., "area", "space", "zone")
         
     | 
| 69 | 
         
            +
                    or structural surface (e.g., "floor", "ground"). If the input describes an interaction near
         
     | 
| 70 | 
         
            +
                    the floor or vague space, you must infer a plausible object like a "table", "cabinet", or "storage box" instead.
         
     | 
| 71 | 
         
            +
                - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} objects must be plausible,
         
     | 
| 72 | 
         
            +
                    and semantically compatible with the {Scene3DItemEnum.CONTEXT} and {Scene3DItemEnum.BACKGROUND}.
         
     | 
| 73 | 
         
            +
                - {Scene3DItemEnum.DISTRACTOR_OBJS} must not confuse or overlap with the manipulated objects.
         
     | 
| 74 | 
         
            +
                - {Scene3DItemEnum.DISTRACTOR_OBJS} number limit: {DISTRACTOR_NUM} distractors maximum.
         
     | 
| 75 | 
         
            +
                - All {Scene3DItemEnum.BACKGROUND} are limited to indoor environments.
         
     | 
| 76 | 
         
            +
                - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} are rigid bodies and not include flexible objects.
         
     | 
| 77 | 
         
            +
                - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be common
         
     | 
| 78 | 
         
            +
                    household or office items or furniture, not abstract concepts, not too small like needle.
         
     | 
| 79 | 
         
            +
                - If the input includes a plural or grouped object (e.g., "pens", "bottles", "plates", "fruit"),
         
     | 
| 80 | 
         
            +
                    you must decompose it into multiple individual instances (e.g., ["pen", "pen"], ["apple", "pear"]).
         
     | 
| 81 | 
         
            +
                - Containers that hold objects (e.g., "bowl of apples", "box of tools") must
         
     | 
| 82 | 
         
            +
                    be separated into individual items (e.g., ["bowl", "apple", "apple"]).
         
     | 
| 83 | 
         
            +
                - Do not include transparent objects such as "glass", "plastic", etc.
         
     | 
| 84 | 
         
            +
                - The output must be in compact JSON format and use Markdown syntax, just like the output in the example below.
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                Examples:
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                Input:
         
     | 
| 89 | 
         
            +
                "Pick up the marker from the table and put it in the bowl robot {RobotItemEnum.UR5}."
         
     | 
| 90 | 
         
            +
                Output:
         
     | 
| 91 | 
         
            +
                ```json
         
     | 
| 92 | 
         
            +
                {{
         
     | 
| 93 | 
         
            +
                    "task_desc": "Pick up the marker from the table and put it in the bowl.",
         
     | 
| 94 | 
         
            +
                    "task": "pick and place",
         
     | 
| 95 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
         
     | 
| 96 | 
         
            +
                    "{Scene3DItemEnum.BACKGROUND}": "kitchen",
         
     | 
| 97 | 
         
            +
                    "{Scene3DItemEnum.CONTEXT}": "table",
         
     | 
| 98 | 
         
            +
                    "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker"],
         
     | 
| 99 | 
         
            +
                    "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["mug", "notebook", "bowl"]
         
     | 
| 100 | 
         
            +
                }}
         
     | 
| 101 | 
         
            +
                ```
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                Input:
         
     | 
| 104 | 
         
            +
                "Put the rubik's cube on the top of the shelf."
         
     | 
| 105 | 
         
            +
                Output:
         
     | 
| 106 | 
         
            +
                ```json
         
     | 
| 107 | 
         
            +
                {{
         
     | 
| 108 | 
         
            +
                    "task_desc": "Put the rubik's cube on the top of the shelf.",
         
     | 
| 109 | 
         
            +
                    "task": "pick and place",
         
     | 
| 110 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
         
     | 
| 111 | 
         
            +
                    "{Scene3DItemEnum.BACKGROUND}": "bedroom",
         
     | 
| 112 | 
         
            +
                    "{Scene3DItemEnum.CONTEXT}": "shelf",
         
     | 
| 113 | 
         
            +
                    "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rubik's cube"],
         
     | 
| 114 | 
         
            +
                    "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["pen", "cup", "toy car"]
         
     | 
| 115 | 
         
            +
                }}
         
     | 
| 116 | 
         
            +
                ```
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                Input:
         
     | 
| 119 | 
         
            +
                "Remove all the objects from the white basket and put them on the table."
         
     | 
| 120 | 
         
            +
                Output:
         
     | 
| 121 | 
         
            +
                ```json
         
     | 
| 122 | 
         
            +
                {{
         
     | 
| 123 | 
         
            +
                    "task_desc": "Remove all the objects from the white basket and put them on the table, robot {RobotItemEnum.PIPER}.",
         
     | 
| 124 | 
         
            +
                    "task": "pick and place",
         
     | 
| 125 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.PIPER}",
         
     | 
| 126 | 
         
            +
                    "{Scene3DItemEnum.BACKGROUND}": "office",
         
     | 
| 127 | 
         
            +
                    "{Scene3DItemEnum.CONTEXT}": "table",
         
     | 
| 128 | 
         
            +
                    "{Scene3DItemEnum.MANIPULATED_OBJS}": ["banana", "mobile phone"],
         
     | 
| 129 | 
         
            +
                    "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["plate", "white basket"]
         
     | 
| 130 | 
         
            +
                }}
         
     | 
| 131 | 
         
            +
                ```
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                Input:
         
     | 
| 134 | 
         
            +
                "Pick up the rope on the chair and put it in the box."
         
     | 
| 135 | 
         
            +
                Output:
         
     | 
| 136 | 
         
            +
                ```json
         
     | 
| 137 | 
         
            +
                {{
         
     | 
| 138 | 
         
            +
                    "task_desc": "Pick up the rope on the chair and put it in the box, robot {RobotItemEnum.FRANKA}.",
         
     | 
| 139 | 
         
            +
                    "task": "pick and place",
         
     | 
| 140 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
         
     | 
| 141 | 
         
            +
                    "{Scene3DItemEnum.BACKGROUND}": "living room",
         
     | 
| 142 | 
         
            +
                    "{Scene3DItemEnum.CONTEXT}": "chair",
         
     | 
| 143 | 
         
            +
                    "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rope", "box"],
         
     | 
| 144 | 
         
            +
                    "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["magazine"]
         
     | 
| 145 | 
         
            +
                }}
         
     | 
| 146 | 
         
            +
                ```
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                Input:
         
     | 
| 149 | 
         
            +
                "Pick up the seal tape and plastic from the counter and put them in the open drawer and close it."
         
     | 
| 150 | 
         
            +
                Output:
         
     | 
| 151 | 
         
            +
                ```json
         
     | 
| 152 | 
         
            +
                {{
         
     | 
| 153 | 
         
            +
                    "task_desc": "Pick up the seal tape and plastic from the counter and put them in the open drawer and close it.",
         
     | 
| 154 | 
         
            +
                    "task": "pick and place",
         
     | 
| 155 | 
         
            +
                    "robot": "franka",
         
     | 
| 156 | 
         
            +
                    "background": "kitchen",
         
     | 
| 157 | 
         
            +
                    "context": "counter",
         
     | 
| 158 | 
         
            +
                    "manipulated_objs": ["seal tape", "plastic", "opened drawer"],
         
     | 
| 159 | 
         
            +
                    "distractor_objs": ["scissors"]
         
     | 
| 160 | 
         
            +
                }}
         
     | 
| 161 | 
         
            +
                ```
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                Input:
         
     | 
| 164 | 
         
            +
                "Put the pens in the grey bowl."
         
     | 
| 165 | 
         
            +
                Output:
         
     | 
| 166 | 
         
            +
                ```json
         
     | 
| 167 | 
         
            +
                {{
         
     | 
| 168 | 
         
            +
                    "task_desc": "Put the pens in the grey bowl.",
         
     | 
| 169 | 
         
            +
                    "task": "pick and place",
         
     | 
| 170 | 
         
            +
                    "robot": "franka",
         
     | 
| 171 | 
         
            +
                    "background": "office",
         
     | 
| 172 | 
         
            +
                    "context": "table",
         
     | 
| 173 | 
         
            +
                    "manipulated_objs": ["pen", "pen", "grey bowl"],
         
     | 
| 174 | 
         
            +
                    "distractor_objs": ["notepad", "cup"]
         
     | 
| 175 | 
         
            +
                }}
         
     | 
| 176 | 
         
            +
                ```
         
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
            """
         
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
            LAYOUT_HIERARCHY_PROMPT = f"""
         
     | 
| 182 | 
         
            +
                You are a 3D scene layout reasoning expert.
         
     | 
| 183 | 
         
            +
                Your task is to generate a spatial relationship dictionary in multiway tree
         
     | 
| 184 | 
         
            +
                that describes how objects are arranged in a 3D environment
         
     | 
| 185 | 
         
            +
                based on a given task description and object list.
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                Input in JSON format containing the task description, task type,
         
     | 
| 188 | 
         
            +
                {Scene3DItemEnum.ROBOT}, {Scene3DItemEnum.BACKGROUND}, {Scene3DItemEnum.CONTEXT},
         
     | 
| 189 | 
         
            +
                and a list of objects, including {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS}.
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
                ### Supported Spatial Relations:
         
     | 
| 192 | 
         
            +
                - "{SpatialRelationEnum.ON}": The child object bottom is directly on top of the parent object top.
         
     | 
| 193 | 
         
            +
                - "{SpatialRelationEnum.INSIDE}": The child object is inside the context object.
         
     | 
| 194 | 
         
            +
                - "{SpatialRelationEnum.IN}": The {Scene3DItemEnum.ROBOT} in the {Scene3DItemEnum.BACKGROUND}.
         
     | 
| 195 | 
         
            +
                - "{SpatialRelationEnum.FLOOR}": The child object bottom is on the floor of the {Scene3DItemEnum.BACKGROUND}.
         
     | 
| 196 | 
         
            +
             
     | 
| 197 | 
         
            +
                ### Rules:
         
     | 
| 198 | 
         
            +
                - The {Scene3DItemEnum.CONTEXT} object must be "{SpatialRelationEnum.FLOOR}" the {Scene3DItemEnum.BACKGROUND}.
         
     | 
| 199 | 
         
            +
                - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be either
         
     | 
| 200 | 
         
            +
                    "{SpatialRelationEnum.ON}" or "{SpatialRelationEnum.INSIDE}" the {Scene3DItemEnum.CONTEXT}
         
     | 
| 201 | 
         
            +
                - Or "{SpatialRelationEnum.FLOOR}" {Scene3DItemEnum.BACKGROUND}.
         
     | 
| 202 | 
         
            +
                - Use "{SpatialRelationEnum.INSIDE}" only if the parent is a container-like object (e.g., shelf, rack, cabinet).
         
     | 
| 203 | 
         
            +
                - Do not define relationship edges between objects, only for the child and parent nodes.
         
     | 
| 204 | 
         
            +
                - {Scene3DItemEnum.ROBOT} must "{SpatialRelationEnum.IN}" the {Scene3DItemEnum.BACKGROUND}.
         
     | 
| 205 | 
         
            +
                - Ensure that each object appears only once in the layout tree, and its spatial relationship is defined with only one parent.
         
     | 
| 206 | 
         
            +
                - Ensure a valid multiway tree structure with a maximum depth of 2 levels suitable for a 3D scene layout representation.
         
     | 
| 207 | 
         
            +
                - Only output the final output in JSON format, using Markdown syntax as in examples.
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
                ### Example
         
     | 
| 210 | 
         
            +
                Input:
         
     | 
| 211 | 
         
            +
                {{
         
     | 
| 212 | 
         
            +
                    "task_desc": "Pick up the marker from the table and put it in the bowl.",
         
     | 
| 213 | 
         
            +
                    "task": "pick and place",
         
     | 
| 214 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
         
     | 
| 215 | 
         
            +
                    "{Scene3DItemEnum.BACKGROUND}": "kitchen",
         
     | 
| 216 | 
         
            +
                    "{Scene3DItemEnum.CONTEXT}": "table",
         
     | 
| 217 | 
         
            +
                    "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker", "bowl"],
         
     | 
| 218 | 
         
            +
                    "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["mug", "chair"]
         
     | 
| 219 | 
         
            +
                }}
         
     | 
| 220 | 
         
            +
                Intermediate Think:
         
     | 
| 221 | 
         
            +
                    table {SpatialRelationEnum.FLOOR} kitchen
         
     | 
| 222 | 
         
            +
                    chair {SpatialRelationEnum.FLOOR} kitchen
         
     | 
| 223 | 
         
            +
                    {RobotItemEnum.FRANKA} {SpatialRelationEnum.IN} kitchen
         
     | 
| 224 | 
         
            +
                    marker {SpatialRelationEnum.ON} table
         
     | 
| 225 | 
         
            +
                    bowl {SpatialRelationEnum.ON} table
         
     | 
| 226 | 
         
            +
                    mug {SpatialRelationEnum.ON} table
         
     | 
| 227 | 
         
            +
                Final Output:
         
     | 
| 228 | 
         
            +
                ```json
         
     | 
| 229 | 
         
            +
                {{
         
     | 
| 230 | 
         
            +
                    "kitchen": [
         
     | 
| 231 | 
         
            +
                        ["table", "{SpatialRelationEnum.FLOOR}"],
         
     | 
| 232 | 
         
            +
                        ["chair", "{SpatialRelationEnum.FLOOR}"],
         
     | 
| 233 | 
         
            +
                        ["{RobotItemEnum.FRANKA}", "{SpatialRelationEnum.IN}"]
         
     | 
| 234 | 
         
            +
                    ],
         
     | 
| 235 | 
         
            +
                    "table": [
         
     | 
| 236 | 
         
            +
                        ["marker", "{SpatialRelationEnum.ON}"],
         
     | 
| 237 | 
         
            +
                        ["bowl", "{SpatialRelationEnum.ON}"],
         
     | 
| 238 | 
         
            +
                        ["mug", "{SpatialRelationEnum.ON}"]
         
     | 
| 239 | 
         
            +
                    ]
         
     | 
| 240 | 
         
            +
                }}
         
     | 
| 241 | 
         
            +
                ```
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                Input:
         
     | 
| 244 | 
         
            +
                {{
         
     | 
| 245 | 
         
            +
                    "task_desc": "Put the marker on top of the book.",
         
     | 
| 246 | 
         
            +
                    "task": "pick and place",
         
     | 
| 247 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
         
     | 
| 248 | 
         
            +
                    "{Scene3DItemEnum.BACKGROUND}": "office",
         
     | 
| 249 | 
         
            +
                    "{Scene3DItemEnum.CONTEXT}": "desk",
         
     | 
| 250 | 
         
            +
                    "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker", "book"],
         
     | 
| 251 | 
         
            +
                    "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["pen holder", "notepad"]
         
     | 
| 252 | 
         
            +
                }}
         
     | 
| 253 | 
         
            +
                Intermediate Think:
         
     | 
| 254 | 
         
            +
                    desk {SpatialRelationEnum.FLOOR} office
         
     | 
| 255 | 
         
            +
                    {RobotItemEnum.UR5} {SpatialRelationEnum.IN} office
         
     | 
| 256 | 
         
            +
                    marker {SpatialRelationEnum.ON} desk
         
     | 
| 257 | 
         
            +
                    book {SpatialRelationEnum.ON} desk
         
     | 
| 258 | 
         
            +
                    pen holder {SpatialRelationEnum.ON} desk
         
     | 
| 259 | 
         
            +
                    notepad {SpatialRelationEnum.ON} desk
         
     | 
| 260 | 
         
            +
                Final Output:
         
     | 
| 261 | 
         
            +
                ```json
         
     | 
| 262 | 
         
            +
                {{
         
     | 
| 263 | 
         
            +
                    "office": [
         
     | 
| 264 | 
         
            +
                        ["desk", "{SpatialRelationEnum.FLOOR}"],
         
     | 
| 265 | 
         
            +
                        ["{RobotItemEnum.UR5}", "{SpatialRelationEnum.IN}"]
         
     | 
| 266 | 
         
            +
                    ],
         
     | 
| 267 | 
         
            +
                    "desk": [
         
     | 
| 268 | 
         
            +
                        ["marker", "{SpatialRelationEnum.ON}"],
         
     | 
| 269 | 
         
            +
                        ["book", "{SpatialRelationEnum.ON}"],
         
     | 
| 270 | 
         
            +
                        ["pen holder", "{SpatialRelationEnum.ON}"],
         
     | 
| 271 | 
         
            +
                        ["notepad", "{SpatialRelationEnum.ON}"]
         
     | 
| 272 | 
         
            +
                    ]
         
     | 
| 273 | 
         
            +
                }}
         
     | 
| 274 | 
         
            +
                ```
         
     | 
| 275 | 
         
            +
             
     | 
| 276 | 
         
            +
                Input:
         
     | 
| 277 | 
         
            +
                {{
         
     | 
| 278 | 
         
            +
                    "task_desc": "Put the rubik's cube on the top of the shelf.",
         
     | 
| 279 | 
         
            +
                    "task": "pick and place",
         
     | 
| 280 | 
         
            +
                    "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
         
     | 
+        "{Scene3DItemEnum.BACKGROUND}": "bedroom",
+        "{Scene3DItemEnum.CONTEXT}": "shelf",
+        "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rubik's cube"],
+        "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["toy car", "pen"]
+    }}
+    Intermediate Think:
+        shelf {SpatialRelationEnum.FLOOR} bedroom
+        {RobotItemEnum.UR5} {SpatialRelationEnum.IN} bedroom
+        rubik's cube {SpatialRelationEnum.INSIDE} shelf
+        toy car {SpatialRelationEnum.INSIDE} shelf
+        pen {SpatialRelationEnum.INSIDE} shelf
+    Final Output:
+    ```json
+    {{
+        "bedroom": [
+            ["shelf", "{SpatialRelationEnum.FLOOR}"],
+            ["{RobotItemEnum.UR5}", "{SpatialRelationEnum.IN}"]
+        ],
+        "shelf": [
+            ["rubik's cube", "{SpatialRelationEnum.INSIDE}"],
+            ["toy car", "{SpatialRelationEnum.INSIDE}"],
+            ["pen", "{SpatialRelationEnum.INSIDE}"]
+        ]
+    }}
+    ```
+
+    Input:
+    {{
+        "task_desc": "Put the marker in the cup on the counter.",
+        "task": "pick and place",
+        "robot": "franka",
+        "background": "kitchen",
+        "context": "counter",
+        "manipulated_objs": ["marker", "cup"],
+        "distractor_objs": ["plate", "spoon"]
+    }}
+    Intermediate Think:
+        counter {SpatialRelationEnum.FLOOR} kitchen
+        {RobotItemEnum.FRANKA} {SpatialRelationEnum.IN} kitchen
+        marker {SpatialRelationEnum.ON} counter
+        cup {SpatialRelationEnum.ON} counter
+        plate {SpatialRelationEnum.ON} counter
+        spoon {SpatialRelationEnum.ON} counter
+    Final Output:
+    ```json
+    {{
+        "kitchen": [
+            ["counter", "{SpatialRelationEnum.FLOOR}"],
+            ["{RobotItemEnum.FRANKA}", "{SpatialRelationEnum.IN}"]
+        ],
+        "counter": [
+            ["marker", "{SpatialRelationEnum.ON}"],
+            ["cup", "{SpatialRelationEnum.ON}"],
+            ["plate", "{SpatialRelationEnum.ON}"],
+            ["spoon", "{SpatialRelationEnum.ON}"]
+        ]
+    }}
+    ```
+"""
+
+
+LAYOUT_DESCRIBER_PROMPT = """
+    You are a 3D asset style descriptor.
+
+    Given a task description and a dictionary where the key is the object content and
+    the value is the object type, output a JSON dictionary with each object paired
+    with a concise, styled visual description suitable for 3D asset generation.
+
+    Generation Guidelines:
+    - For each object, brainstorm multiple style candidates before selecting the final
+        description. Vary phrasing, material, texture, color, and spatial details.
+    - Each description must be a maximum of 15 words, including color, style, and materials.
+    - Descriptions should be visually grounded, specific, and reflect surface texture and structure.
+    - For objects marked as "context", explicitly mention that the object is standalone and has an empty top.
+    - Use rich style descriptors, e.g., "scratched brown wooden desk".
+    - Ensure all object styles align with the task's overall context and environment.
+
+    Format your output in JSON like the example below.
+
+    Example Input:
+    "Pick up the rope on the chair and put it in the box. {'living room': 'background', 'chair': 'context',
+        'rope': 'manipulated_objs', 'box': 'manipulated_objs', 'magazine': 'distractor_objs'}"
+
+    Example Output:
+    ```json
+    {
+        "living room": "modern cozy living room with soft sunlight and light grey carpet",
+        "chair": "standalone dark oak chair with no surroundings and clean empty seat",
+        "rope": "twisted hemp rope with rough fibers and dusty beige texture",
+        "box": "slightly crumpled cardboard box with open flaps and brown textured surface",
+        "magazine": "celebrity magazine with glossy red cover and large bold title"
+    }
+    ```
+"""
+
+
+class LayoutDesigner(object):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        system_prompt: str,
+        verbose: bool = False,
+    ) -> None:
+        self.prompt = system_prompt.strip()
+        self.verbose = verbose
+        self.gpt_client = gpt_client
+
+    def query(self, prompt: str, params: dict = None) -> str:
+        full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
+
+        response = self.gpt_client.query(
+            text_prompt=full_prompt,
+            params=params,
+        )
+
+        if self.verbose:
+            logger.info(f"Response: {response}")
+
+        return response
+
+    def format_response(self, response: str) -> dict:
+        cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
+        try:
+            output = json.loads(cleaned)
+        except json.JSONDecodeError as e:
+            # Re-raise with the offending payload attached for easier debugging.
+            raise json.JSONDecodeError(
+                f"{e.msg}, failed to parse JSON response: {response}",
+                e.doc,
+                e.pos,
+            )
+
+        return output
+
+    def format_response_repair(self, response: str) -> dict:
+        return json_repair.loads(response)
+
+    def save_output(self, output: dict, save_path: str) -> None:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        with open(save_path, 'w') as f:
+            json.dump(output, f, indent=4)
+
+    def __call__(
+        self, prompt: str, save_path: str = None, params: dict = None
+    ) -> dict | str:
+        response = self.query(prompt, params=params)
+        output = self.format_response_repair(response)
+        if save_path:
+            self.save_output(output, save_path)
+
+        return output
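For reference, a minimal standalone usage sketch of `LayoutDesigner` (the task string, sampling params, and save path below are illustrative placeholders, not values taken from this module):

```python
# Hypothetical standalone use: the call queries the GPT client, repairs the
# returned JSON with json_repair, and optionally writes it to disk.
designer = LayoutDesigner(
    gpt_client=GPT_CLIENT,
    system_prompt=LAYOUT_DESCRIBER_PROMPT,
    verbose=True,
)
styles = designer(
    "Put the apples on the table on the plate. "
    "{'kitchen': 'background', 'table': 'context', 'apple': 'manipulated_objs'}",
    save_path="outputs/layout_tree/objs_desc.json",
    params={"temperature": 1.0},
)
```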
+
+
+LAYOUT_DISASSEMBLER = LayoutDesigner(
+    gpt_client=GPT_CLIENT, system_prompt=LAYOUT_DISASSEMBLE_PROMPT
+)
+LAYOUT_GRAPHER = LayoutDesigner(
+    gpt_client=GPT_CLIENT, system_prompt=LAYOUT_HIERARCHY_PROMPT
+)
+LAYOUT_DESCRIBER = LayoutDesigner(
+    gpt_client=GPT_CLIENT, system_prompt=LAYOUT_DESCRIBER_PROMPT
+)
+
+
+def build_scene_layout(
+    task_desc: str, output_path: str = None, gpt_params: dict = None
+) -> LayoutInfo:
+    layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
+    layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
+    object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
+    obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
+    objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=gpt_params)
+    layout_info = LayoutInfo(
+        layout_tree, layout_relation, objs_desc, object_mapping
+    )
+
+    if output_path is not None:
+        visualizer = SceneTreeVisualizer(layout_info)
+        visualizer.render(save_path=output_path)
+        logger.info(f"Scene hierarchy tree saved to {output_path}")
+
+    return layout_info
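A minimal sketch of calling `build_scene_layout` directly (the task string, output path, and GPT sampling values are illustrative):

```python
# One call runs the three designers in sequence (disassemble -> hierarchy ->
# style description) and optionally renders the scene-tree image.
layout_info = build_scene_layout(
    task_desc="Put the marker in the cup on the counter.",
    output_path="outputs/layout_tree/scene_tree.jpg",
    gpt_params={"temperature": 1.0, "top_p": 0.95},
)
print(layout_info.to_dict())
```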
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="3D Scene Layout Designer")
+    parser.add_argument(
+        "--task_desc",
+        type=str,
+        default="Put the apples on the table on the plate",
+        help="Natural language description of the robotic task",
+    )
+    parser.add_argument(
+        "--save_root",
+        type=str,
+        default="outputs/layout_tree",
+        help="Path to save the layout output",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    from embodied_gen.utils.enum import LayoutInfo
+    from embodied_gen.utils.process_media import SceneTreeVisualizer
+
+    args = parse_args()
+    params = {
+        "temperature": 1.0,
+        "top_p": 0.95,
+        "frequency_penalty": 0.3,
+        "presence_penalty": 0.5,
+    }
+    layout_relation = LAYOUT_DISASSEMBLER(args.task_desc, params=params)
+    layout_tree = LAYOUT_GRAPHER(layout_relation, params=params)
+
+    object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
+    obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
+
+    objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=params)
+
+    layout_info = LayoutInfo(layout_tree, layout_relation, objs_desc)
+
+    visualizer = SceneTreeVisualizer(layout_info)
+    os.makedirs(args.save_root, exist_ok=True)
+    scene_graph_path = f"{args.save_root}/scene_tree.jpg"
+    visualizer.render(save_path=scene_graph_path)
+    with open(f"{args.save_root}/layout.json", "w") as f:
+        json.dump(layout_info.to_dict(), f, indent=4)
+
+    print(f"Scene hierarchy tree saved to {scene_graph_path}")
+    print(f"Disassembled Layout: {layout_relation}")
+    print(f"Layout Graph: {layout_tree}")
+    print(f"Layout Descriptions: {objs_desc}")
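When run as a script (assuming the package is importable, e.g. `python -m embodied_gen.models.layout --task_desc "Put the apples on the table on the plate" --save_root outputs/layout_tree`), the module writes the rendered scene tree to `<save_root>/scene_tree.jpg` and the layout dictionary to `<save_root>/layout.json`, then prints the intermediate relation, graph, and description outputs.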
    	
embodied_gen/models/sr_model.py CHANGED

@@ -53,7 +53,7 @@ class ImageStableSR:
             torch_dtype=torch.float16,
         ).to(device)
         self.up_pipeline_x4.set_progress_bar_config(disable=True)
-
+        self.up_pipeline_x4.enable_model_cpu_offload()

     @spaces.GPU
     def __call__(
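The added call swaps fixed GPU residency for accelerate-backed CPU offloading: submodules stay on CPU and are moved to the GPU only while they execute, trading some latency for a lower peak VRAM footprint. A minimal sketch of the same pattern in isolation (the model id is an assumption for illustration, not taken from this file):

```python
import torch
from diffusers import StableDiffusionUpscalePipeline

# Offload hooks move each submodule to the GPU on demand instead of keeping
# the whole pipeline resident, which is what .to("cuda") alone would do.
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
```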
    	
embodied_gen/models/text_model.py CHANGED

@@ -135,7 +135,7 @@ def build_text2img_ip_pipeline(

     pipe = pipe.to(device)
     pipe.image_encoder = pipe.image_encoder.to(device)
-
+    pipe.enable_model_cpu_offload()
     # pipe.enable_xformers_memory_efficient_attention()
     # pipe.enable_vae_slicing()

@@ -168,8 +168,8 @@ def build_text2img_pipeline(
         force_zeros_for_empty_prompt=False,
     )
     pipe = pipe.to(device)
-
-
+    pipe.enable_model_cpu_offload()
+    pipe.enable_xformers_memory_efficient_attention()

     return pipe

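Note that `enable_xformers_memory_efficient_attention()` only succeeds when the xformers package is installed and supported on the platform; a guarded variant is a common defensive sketch (not what this file does):

```python
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception:
    # xformers missing or unsupported; fall back to the default attention.
    pass
```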
    	
embodied_gen/models/texture_model.py CHANGED

@@ -106,6 +106,6 @@ def build_texture_gen_pipe(
         pipe.set_ip_adapter_scale([ip_adapt_scale])

     pipe = pipe.to(device)
-
+    pipe.enable_model_cpu_offload()

     return pipe
    	
embodied_gen/scripts/compose_layout.py ADDED

@@ -0,0 +1,73 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import json
+import os
+from dataclasses import dataclass
+
+import tyro
+from embodied_gen.scripts.simulate_sapien import entrypoint as sim_cli
+from embodied_gen.utils.enum import LayoutInfo
+from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene
+from embodied_gen.utils.log import logger
+
+
+@dataclass
+class LayoutPlacementConfig:
+    layout_path: str
+    output_dir: str | None = None
+    seed: int | None = None
+    max_attempts: int = 1000
+    output_iscene: bool = False
+    insert_robot: bool = False
+
+
+def entrypoint(**kwargs):
+    if kwargs is None or len(kwargs) == 0:
+        args = tyro.cli(LayoutPlacementConfig)
+    else:
+        args = LayoutPlacementConfig(**kwargs)
+
+    output_dir = (
+        args.output_dir
+        if args.output_dir is not None
+        else os.path.dirname(args.layout_path)
+    )
+    os.makedirs(output_dir, exist_ok=True)
+    out_scene_path = f"{output_dir}/Iscene.glb"
+    out_layout_path = f"{output_dir}/layout.json"
+
+    with open(args.layout_path, "r") as f:
+        layout_info = LayoutInfo.from_dict(json.load(f))
+
+    layout_info = bfs_placement(layout_info, seed=args.seed)
+    with open(out_layout_path, "w") as f:
+        json.dump(layout_info.to_dict(), f, indent=4)
+
+    if args.output_iscene:
+        compose_mesh_scene(layout_info, out_scene_path)
+
+    sim_cli(
+        layout_path=out_layout_path,
+        output_dir=output_dir,
+        robot_name="franka" if args.insert_robot else None,
+    )
+
+    logger.info(f"Layout placement completed in {output_dir}")
+
+
+if __name__ == "__main__":
+    entrypoint()
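A minimal usage sketch of the new placement entrypoint called programmatically rather than via the CLI (the layout path is illustrative):

```python
from embodied_gen.scripts.compose_layout import entrypoint

# Re-run BFS placement on an existing layout.json, export the composed
# Iscene.glb, and launch the SAPIEN simulation step with a Franka robot.
entrypoint(
    layout_path="outputs/layouts/task_0000/layout.json",
    output_iscene=True,
    insert_robot=True,
)
```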
    	
embodied_gen/scripts/gen_layout.py ADDED

@@ -0,0 +1,156 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import gc
+import json
+import os
+from dataclasses import dataclass, field
+from shutil import copytree
+from time import time
+from typing import Optional
+
+import torch
+import tyro
+from embodied_gen.models.layout import build_scene_layout
+from embodied_gen.scripts.simulate_sapien import entrypoint as sim_cli
+from embodied_gen.scripts.textto3d import text_to_3d
+from embodied_gen.utils.config import GptParamsConfig
+from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
+from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene
+from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import (
+    load_scene_dict,
+    parse_text_prompts,
+)
+from embodied_gen.validators.quality_checkers import SemanticMatcher
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+
+@dataclass
+class LayoutGenConfig:
+    task_descs: list[str]
+    output_root: str
+    bg_list: str = "outputs/bg_scenes/scene_list.txt"
+    n_img_sample: int = 3
+    text_guidance_scale: float = 7.0
+    img_denoise_step: int = 25
+    n_image_retry: int = 4
+    n_asset_retry: int = 3
+    n_pipe_retry: int = 2
+    seed_img: Optional[int] = None
+    seed_3d: Optional[int] = None
+    seed_layout: Optional[int] = None
+    keep_intermediate: bool = False
+    output_iscene: bool = False
+    insert_robot: bool = False
+    gpt_params: GptParamsConfig = field(
+        default_factory=lambda: GptParamsConfig(
+            temperature=1.0,
+            top_p=0.95,
+            frequency_penalty=0.3,
+            presence_penalty=0.5,
+        )
+    )
+
+
+def entrypoint() -> None:
+    args = tyro.cli(LayoutGenConfig)
+    SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
+    task_descs = parse_text_prompts(args.task_descs)
+    scene_dict = load_scene_dict(args.bg_list)
+    gpt_params = args.gpt_params.to_dict()
+    for idx, task_desc in enumerate(task_descs):
+        logger.info(f"Generate Layout and 3D scene for task: {task_desc}")
+        output_root = f"{args.output_root}/task_{idx:04d}"
+        scene_graph_path = f"{output_root}/scene_tree.jpg"
+        start_time = time()
+        layout_info: LayoutInfo = build_scene_layout(
+            task_desc, scene_graph_path, gpt_params
+        )
+        prompts_mapping = {v: k for k, v in layout_info.objs_desc.items()}
+        prompts = [
+            v
+            for k, v in layout_info.objs_desc.items()
+            if layout_info.objs_mapping[k] != Scene3DItemEnum.BACKGROUND.value
+        ]
+
+        for prompt in prompts:
+            node = prompts_mapping[prompt]
+            generation_log = text_to_3d(
+                prompts=[
+                    prompt,
+                ],
+                output_root=output_root,
+                asset_names=[
+                    node,
+                ],
+                n_img_sample=args.n_img_sample,
+                text_guidance_scale=args.text_guidance_scale,
+                img_denoise_step=args.img_denoise_step,
+                n_image_retry=args.n_image_retry,
+                n_asset_retry=args.n_asset_retry,
+                n_pipe_retry=args.n_pipe_retry,
+                seed_img=args.seed_img,
+                seed_3d=args.seed_3d,
+                keep_intermediate=args.keep_intermediate,
+            )
+            layout_info.assets.update(generation_log["assets"])
+            layout_info.quality.update(generation_log["quality"])
+
+        # Background GEN (for efficiency, temp use retrieval instead)
+        bg_node = layout_info.relation[Scene3DItemEnum.BACKGROUND.value]
+        text = layout_info.objs_desc[bg_node]
+        match_key = SCENE_MATCHER.query(text, str(scene_dict))
+        match_scene_path = f"{os.path.dirname(args.bg_list)}/{match_key}"
+        bg_save_dir = os.path.join(output_root, "background")
+        copytree(match_scene_path, bg_save_dir, dirs_exist_ok=True)
+        layout_info.assets[bg_node] = bg_save_dir
+
+        # BFS layout placement.
+        layout_info = bfs_placement(
+            layout_info,
+            limit_reach_range=True if args.insert_robot else False,
+            seed=args.seed_layout,
+        )
+        layout_path = f"{output_root}/layout.json"
+        with open(layout_path, "w") as f:
+            json.dump(layout_info.to_dict(), f, indent=4)
+
+        if args.output_iscene:
+            compose_mesh_scene(layout_info, f"{output_root}/Iscene.glb")
+
+        sim_cli(
+            layout_path=layout_path,
+            output_dir=output_root,
+            robot_name="franka" if args.insert_robot else None,
+        )
+
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        elapsed_time = (time() - start_time) / 60
+        logger.info(
+            f"Layout generation done for {scene_graph_path}, layout result "
+            f"in {layout_path}, finished in {elapsed_time:.2f} mins."
+        )
+
+    logger.info(f"All tasks completed in {args.output_root}")
+
+
+if __name__ == "__main__":
+    entrypoint()
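A sketch of driving the generator programmatically by handing tyro an explicit argv. The flag spellings assume tyro's default underscore-to-dash rendering, and the task string and output path are illustrative, so treat the exact names as assumptions rather than documented options:

```python
import sys
from embodied_gen.scripts.gen_layout import entrypoint

# entrypoint() parses its config via tyro.cli, so we emulate a CLI call.
sys.argv = [
    "gen_layout",
    "--task-descs", "Put the marker in the cup on the counter.",
    "--output-root", "outputs/layouts",
]
entrypoint()
```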
    	
        embodied_gen/scripts/imageto3d.py
    CHANGED
    
    | 
         @@ -58,7 +58,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" 
     | 
|
| 58 | 
         
             
            os.environ["SPCONV_ALGO"] = "native"
         
     | 
| 59 | 
         
             
            random.seed(0)
         
     | 
| 60 | 
         | 
| 61 | 
         
            -
            logger.info("Loading Models...")
         
     | 
| 62 | 
         
             
            DELIGHT = DelightingModel()
         
     | 
| 63 | 
         
             
            IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
         
     | 
| 64 | 
         
             
            RBG_REMOVER = RembgRemover()
         
     | 
| 
         @@ -107,6 +107,7 @@ def parse_args(): 
     | 
|
| 107 | 
         
             
                    type=int,
         
     | 
| 108 | 
         
             
                    default=2,
         
     | 
| 109 | 
         
             
                )
         
     | 
| 
         | 
|
| 110 | 
         
             
                args, unknown = parser.parse_known_args()
         
     | 
| 111 | 
         | 
| 112 | 
         
             
                return args
         
     | 
| 
         @@ -151,6 +152,9 @@ def entrypoint(**kwargs): 
     | 
|
| 151 | 
         
             
                        seg_image.save(seg_path)
         
     | 
| 152 | 
         | 
| 153 | 
         
             
                        seed = args.seed
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 154 | 
         
             
                        for try_idx in range(args.n_retry):
         
     | 
| 155 | 
         
             
                            logger.info(
         
     | 
| 156 | 
         
             
                                f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
         
     | 
| 
         @@ -207,7 +211,9 @@ def entrypoint(**kwargs): 
     | 
|
| 207 | 
         
             
                            color_path = os.path.join(output_root, "color.png")
         
     | 
| 208 | 
         
             
                            render_gs_api(aligned_gs_path, color_path)
         
     | 
| 209 | 
         | 
| 210 | 
         
            -
                            geo_flag, geo_result = GEO_CHECKER( 
     | 
| 
         | 
|
| 
         | 
|
| 211 | 
         
             
                            logger.warning(
         
     | 
| 212 | 
         
             
                                f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
         
     | 
| 213 | 
         
             
                            )
         
     | 
| 
         @@ -246,7 +252,11 @@ def entrypoint(**kwargs): 
     | 
|
| 246 | 
         
             
                        mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
         
     | 
| 247 | 
         
             
                        mesh.export(mesh_glb_path)
         
     | 
| 248 | 
         | 
| 249 | 
         
            -
                        urdf_convertor = URDFGenerator( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 250 | 
         
             
                        asset_attrs = {
         
     | 
| 251 | 
         
             
                            "version": VERSION,
         
     | 
| 252 | 
         
             
                            "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
         
     | 
| 
         | 
|
| 58 | 
         
             
            os.environ["SPCONV_ALGO"] = "native"
         
     | 
| 59 | 
         
             
            random.seed(0)
         
     | 
| 60 | 
         | 
| 61 | 
         
            +
            logger.info("Loading Image3D Models...")
         
     | 
| 62 | 
         
             
            DELIGHT = DelightingModel()
         
     | 
| 63 | 
         
             
     IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
     RBG_REMOVER = RembgRemover()
@@
         type=int,
         default=2,
     )
+    parser.add_argument("--disable_decompose_convex", action="store_true")
     args, unknown = parser.parse_known_args()

     return args
@@
                 seg_image.save(seg_path)

                 seed = args.seed
+                asset_node = "unknown"
+                if isinstance(args.asset_type, list) and args.asset_type[idx]:
+                    asset_node = args.asset_type[idx]
                 for try_idx in range(args.n_retry):
                     logger.info(
                         f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
@@
                     color_path = os.path.join(output_root, "color.png")
                     render_gs_api(aligned_gs_path, color_path)

+                    geo_flag, geo_result = GEO_CHECKER(
+                        [color_path], text=asset_node
+                    )
                     logger.warning(
                         f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
                     )
@@
                 mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
                 mesh.export(mesh_glb_path)

+                urdf_convertor = URDFGenerator(
+                    GPT_CLIENT,
+                    render_view_num=4,
+                    decompose_convex=not args.disable_decompose_convex,
+                )
                 asset_attrs = {
                     "version": VERSION,
                     "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
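Illustration (not part of the diff): the new hunks wire a geometry check into the existing retry loop, where each attempt renders the aligned splat to color.png, asks the checker whether it looks like the expected asset type, and moves on to a new seed otherwise. A minimal self-contained sketch of that gating pattern with a stand-in checker (the real GEO_CHECKER presumably lives in embodied_gen/validators/quality_checkers.py and returns a (flag, reason) pair as used above):

    import random

    def fake_geo_checker(image_paths: list[str], text: str) -> tuple[bool, str]:
        """Stand-in for GEO_CHECKER: returns (passed, reason)."""
        passed = random.random() > 0.3
        return passed, "ok" if passed else f"render does not look like a {text}"

    def generate_with_retry(asset_node: str, n_retry: int = 3, seed: int = 0):
        for try_idx in range(n_retry):
            color_path = f"/tmp/color_seed{seed}.png"  # placeholder render output
            geo_flag, geo_result = fake_geo_checker([color_path], text=asset_node)
            print(f"Try {try_idx + 1}/{n_retry}, seed {seed}: {geo_result}")
            if geo_flag:
                return seed  # accept this attempt
            seed += 1  # retry with a new seed, mirroring the script's loop
        return None

    generate_with_retry("mug")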
    	
embodied_gen/scripts/parallel_sim.py
ADDED
@@ -0,0 +1,148 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

monkey_patch_maniskill()
import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Literal

import gymnasium as gym
import numpy as np
import torch
import tyro
from mani_skill.utils.wrappers import RecordEpisode
from tqdm import tqdm
import embodied_gen.envs.pick_embodiedgen
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.log import logger
from embodied_gen.utils.simulation import FrankaPandaGrasper


@dataclass
class ParallelSimConfig:
    """CLI parameters for parallel SAPIEN simulation."""

    # Environment configuration
    layout_file: str
    """Path to the layout JSON file"""
    output_dir: str
    """Directory to save recorded videos"""
    gym_env_name: str = "PickEmbodiedGen-v1"
    """Name of the Gym environment to use"""
    num_envs: int = 4
    """Number of parallel environments"""
    render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
    """Rendering mode: rgb_array or hybrid"""
    enable_shadow: bool = True
    """Whether to enable shadows in rendering"""
    control_mode: str = "pd_joint_pos"
    """Control mode for the agent"""

    # Recording configuration
    max_steps_per_video: int = 1000
    """Maximum steps to record per video"""
    save_trajectory: bool = False
    """Whether to save trajectory data"""

    # Simulation parameters
    seed: int = 0
    """Random seed for environment reset"""
    warmup_steps: int = 50
    """Number of warmup steps before action computation"""
    reach_target_only: bool = True
    """Whether to only reach the target without the full grasp action"""


def entrypoint(**kwargs):
    if kwargs is None or len(kwargs) == 0:
        cfg = tyro.cli(ParallelSimConfig)
    else:
        cfg = ParallelSimConfig(**kwargs)

    env = gym.make(
        cfg.gym_env_name,
        num_envs=cfg.num_envs,
        render_mode=cfg.render_mode,
        enable_shadow=cfg.enable_shadow,
        layout_file=cfg.layout_file,
        control_mode=cfg.control_mode,
    )
    env = RecordEpisode(
        env,
        cfg.output_dir,
        max_steps_per_video=cfg.max_steps_per_video,
        save_trajectory=cfg.save_trajectory,
    )
    env.reset(seed=cfg.seed)

    default_action = env.unwrapped.agent.init_qpos[:, :8]
    for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
        # action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(default_action)

    grasper = FrankaPandaGrasper(
        env.unwrapped.agent,
        env.unwrapped.sim_config.control_freq,
    )

    layout_data = LayoutInfo.from_dict(json.load(open(cfg.layout_file, "r")))
    actions = defaultdict(list)
    # Plan a grasp reach pose for each manipulated object in each env.
    for env_idx in range(env.num_envs):
        actors = env.unwrapped.env_actors[f"env{env_idx}"]
        for node in layout_data.relation[
            Scene3DItemEnum.MANIPULATED_OBJS.value
        ]:
            action = grasper.compute_grasp_action(
                actor=actors[node]._objs[0],
                reach_target_only=True,
                env_idx=env_idx,
            )
            actions[node].append(action)

    # Execute the planned actions for each manipulated object in each env.
    for node in actions:
        max_env_steps = 0
        for env_idx in range(env.num_envs):
            if actions[node][env_idx] is None:
                continue
            max_env_steps = max(max_env_steps, len(actions[node][env_idx]))

        action_tensor = np.ones(
            (max_env_steps, env.num_envs, env.action_space.shape[-1])
        )
        action_tensor *= default_action[None, ...]
        for env_idx in range(env.num_envs):
            action = actions[node][env_idx]
            if action is None:
                continue
            action_tensor[: len(action), env_idx, :] = action

        for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
            action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
            env.unwrapped.agent.set_action(action)
            obs, reward, terminated, truncated, info = env.step(action)

    env.close()
    logger.info(f"Results saved in {cfg.output_dir}")


if __name__ == "__main__":
    entrypoint()
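Illustration (not part of the diff): since entrypoint() forwards keyword arguments straight into ParallelSimConfig, the script can be driven programmatically as well as through the tyro CLI. A minimal sketch; the layout and output paths below are placeholders and all other fields keep their dataclass defaults:

    from embodied_gen.scripts.parallel_sim import entrypoint

    entrypoint(
        layout_file="outputs/layouts/example/layout.json",  # placeholder path
        output_dir="outputs/parallel_sim",                  # placeholder path
        num_envs=4,
        render_mode="hybrid",
    )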
    	
embodied_gen/scripts/simulate_sapien.py
ADDED
@@ -0,0 +1,195 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Literal

import imageio
import numpy as np
import torch
import tyro
from tqdm import tqdm
from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.geometry import quaternion_multiply
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import alpha_blend_rgba
from embodied_gen.utils.simulation import (
    SIM_COORD_ALIGN,
    FrankaPandaGrasper,
    SapienSceneManager,
    load_assets_from_layout_file,
    load_mani_skill_robot,
    render_images,
)


@dataclass
class SapienSimConfig:
    # Simulation settings.
    layout_path: str
    output_dir: str
    sim_freq: int = 200
    sim_step: int = 400
    z_offset: float = 0.004
    init_quat: list[float] = field(
        default_factory=lambda: [0.7071, 0, 0, 0.7071]
    )  # xyzw
    device: str = "cuda"
    control_freq: int = 50
    insert_robot: bool = False
    # Camera settings.
    render_interval: int = 10
    num_cameras: int = 3
    camera_radius: float = 0.9
    camera_height: float = 1.1
    image_hw: tuple[int, int] = (512, 512)
    ray_tracing: bool = True
    fovy_deg: float = 75.0
    camera_target_pt: list[float] = field(
        default_factory=lambda: [0.0, 0.0, 0.9]
    )
    render_keys: list[
        Literal[
            "Color", "Foreground", "Segmentation", "Normal", "Mask", "Depth"
        ]
    ] = field(default_factory=lambda: ["Foreground"])


def entrypoint(**kwargs):
    if kwargs is None or len(kwargs) == 0:
        cfg = tyro.cli(SapienSimConfig)
    else:
        cfg = SapienSimConfig(**kwargs)

    scene_manager = SapienSceneManager(
        cfg.sim_freq, ray_tracing=cfg.ray_tracing
    )
    _ = scene_manager.initialize_circular_cameras(
        num_cameras=cfg.num_cameras,
        radius=cfg.camera_radius,
        height=cfg.camera_height,
        target_pt=cfg.camera_target_pt,
        image_hw=cfg.image_hw,
        fovy_deg=cfg.fovy_deg,
    )
    with open(cfg.layout_path, "r") as f:
        layout_data = json.load(f)
        layout_data: LayoutInfo = LayoutInfo.from_dict(layout_data)

    actors = load_assets_from_layout_file(
        scene_manager.scene,
        layout_data,
        cfg.z_offset,
        cfg.init_quat,
    )
    agent = load_mani_skill_robot(
        scene_manager.scene, layout_data, cfg.control_freq
    )

    frames = defaultdict(list)
    image_cnt = 0
    for step in tqdm(range(cfg.sim_step), desc="Simulation"):
        scene_manager.scene.step()
        agent.reset(agent.init_qpos)
        if step % cfg.render_interval != 0:
            continue
        scene_manager.scene.update_render()
        image_cnt += 1
        for camera in scene_manager.cameras:
            camera.take_picture()
            images = render_images(camera, cfg.render_keys)
            frames[camera.name].append(images)

    actions = dict()
    if cfg.insert_robot:
        grasper = FrankaPandaGrasper(
            agent,
            cfg.control_freq,
        )
        for node in layout_data.relation[
            Scene3DItemEnum.MANIPULATED_OBJS.value
        ]:
            actions[node] = grasper.compute_grasp_action(
                actor=actors[node], reach_target_only=True
            )

    if "Foreground" not in cfg.render_keys:
        return

    bg_node = layout_data.relation[Scene3DItemEnum.BACKGROUND.value]
    gs_path = f"{layout_data.assets[bg_node]}/gs_model.ply"
    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
    x, y, z, qx, qy, qz, qw = layout_data.position[bg_node]
    qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], cfg.init_quat)
    init_pose = torch.tensor([x, y, z, qx, qy, qz, qw])
    gs_model = gs_model.get_gaussians(instance_pose=init_pose)

    bg_images = dict()
    for camera in scene_manager.cameras:
        Ks = camera.get_intrinsic_matrix()
        c2w = camera.get_model_matrix()
        c2w = c2w @ SIM_COORD_ALIGN
        result = gs_model.render(
            torch.tensor(c2w, dtype=torch.float32).to(cfg.device),
            torch.tensor(Ks, dtype=torch.float32).to(cfg.device),
            image_width=cfg.image_hw[1],
            image_height=cfg.image_hw[0],
        )
        bg_images[camera.name] = result.rgb[..., ::-1]

    video_frames = []
    for camera in scene_manager.cameras:
        # Scene rendering
        for step in range(image_cnt):
            rgba = alpha_blend_rgba(
                frames[camera.name][step]["Foreground"],
                bg_images[camera.name],
            )
            video_frames.append(np.array(rgba))

        # Grasp rendering
        for node in actions:
            if actions[node] is None:
                continue
            for action in tqdm(actions[node]):
                grasp_frames = scene_manager.step_action(
                    agent,
                    torch.Tensor(action[None, ...]),
                    scene_manager.cameras,
                    cfg.render_keys,
                    sim_steps_per_control=cfg.sim_freq // cfg.control_freq,
                )
                rgba = alpha_blend_rgba(
                    grasp_frames[camera.name][0]["Foreground"],
                    bg_images[camera.name],
                )
                video_frames.append(np.array(rgba))

            agent.reset(agent.init_qpos)

    os.makedirs(cfg.output_dir, exist_ok=True)
    video_path = f"{cfg.output_dir}/Iscene.mp4"
    imageio.mimsave(video_path, video_frames, fps=30)
    logger.info(f"Interactive 3D scene visualization saved in {video_path}")


if __name__ == "__main__":
    entrypoint()
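Illustration (not part of the diff): this script follows the same entrypoint(**kwargs) pattern, rendering "Foreground" passes, compositing them over the background Gaussian-splat render, and writing <output_dir>/Iscene.mp4. A minimal programmatic sketch with placeholder paths:

    from embodied_gen.scripts.simulate_sapien import entrypoint

    entrypoint(
        layout_path="outputs/layouts/example/layout.json",  # placeholder path
        output_dir="outputs/sapien_sim",                    # placeholder path
        insert_robot=False,
    )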
    	
embodied_gen/scripts/textto3d.py
CHANGED
@@ -42,7 +42,7 @@ from embodied_gen.validators.quality_checkers import (
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 random.seed(0)

-logger.info("Loading …
+logger.info("Loading TEXT2IMG_MODEL...")
 SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
@@ -170,6 +170,7 @@ def text_to_3d(**kwargs) -> dict:
                 seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
                 n_retry=args.n_asset_retry,
                 keep_intermediate=args.keep_intermediate,
+                disable_decompose_convex=args.disable_decompose_convex,
             )
             mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
             image_path = render_asset3d(
@@ -270,6 +271,7 @@ def parse_args():
         help="Random seed for 3D generation",
     )
     parser.add_argument("--keep_intermediate", action="store_true")
+    parser.add_argument("--disable_decompose_convex", action="store_true")

     args, unknown = parser.parse_known_args()
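Illustration (not part of the diff): the new flag is a store_true switch, so convex decomposition stays enabled by default and the URDF step receives the negated value (decompose_convex=not args.disable_decompose_convex), as in the imageto3d.py hunks above. A tiny standalone demonstration of that inversion:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--disable_decompose_convex", action="store_true")

    args = parser.parse_args([])                               # flag absent
    print(not args.disable_decompose_convex)                   # True  -> decomposition runs
    args = parser.parse_args(["--disable_decompose_convex"])   # flag present
    print(not args.disable_decompose_convex)                   # False -> decomposition skipped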
         | 
    	
embodied_gen/scripts/textto3d.sh
CHANGED
@@ -81,6 +81,7 @@ done


 # Step 1: Text-to-Image
+echo ${prompt_args}
 eval python3 embodied_gen/scripts/text2image.py \
     --prompts ${prompt_args} \
     --output_root "${output_root}/images" \
    	
embodied_gen/trainer/gsplat_trainer.py
CHANGED
@@ -617,7 +617,7 @@ class Runner:
         for rgb, depth in images_cache:
             depth_normalized = torch.clip(
                 (depth - depth_global_min)
-                / (depth_global_max - depth_global_min),
+                / (depth_global_max - depth_global_min + 1e-8),
                 0,
                 1,
             )
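Illustration (not part of the diff): the added 1e-8 guards the normalization when a batch has a constant depth map (depth_global_max == depth_global_min), which previously produced 0/0 = NaN. A quick check of both cases:

    import torch

    depth = torch.full((2, 2), 3.0)              # constant depth map
    d_min, d_max = depth.min(), depth.max()

    bad = torch.clip((depth - d_min) / (d_max - d_min), 0, 1)           # nan
    good = torch.clip((depth - d_min) / (d_max - d_min + 1e-8), 0, 1)   # 0.0
    print(bad[0, 0].item(), good[0, 0].item())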
    	
embodied_gen/trainer/pono2mesh_trainer.py
CHANGED
@@ -30,7 +30,7 @@ from kornia.morphology import dilation
 from PIL import Image
 from embodied_gen.models.sr_model import ImageRealESRGAN
 from embodied_gen.utils.config import Pano2MeshSRConfig
-from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
+from embodied_gen.utils.geometry import compute_pinhole_intrinsics
 from embodied_gen.utils.log import logger
 from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
 from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
    	
embodied_gen/utils/config.py
CHANGED
@@ -17,15 +17,27 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Union

+from dataclasses_json import DataClassJsonMixin
 from gsplat.strategy import DefaultStrategy, MCMCStrategy
 from typing_extensions import Literal, assert_never

 __all__ = [
+    "GptParamsConfig",
     "Pano2MeshSRConfig",
     "GsplatTrainConfig",
 ]


+@dataclass
+class GptParamsConfig(DataClassJsonMixin):
+    temperature: float = 0.1
+    top_p: float = 0.1
+    frequency_penalty: float = 0.0
+    presence_penalty: float = 0.0
+    stop: int | None = None
+    max_tokens: int = 500
+
+
 @dataclass
 class Pano2MeshSRConfig:
     mesh_file: str = "mesh_model.ply"
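Illustration (not part of the diff): because GptParamsConfig mixes in DataClassJsonMixin, the GPT sampling parameters round-trip through plain dicts and JSON, which is presumably how they are persisted or passed to the GPT client. A sketch of that round-trip:

    from embodied_gen.utils.config import GptParamsConfig

    params = GptParamsConfig(temperature=0.3, max_tokens=800)
    as_dict = params.to_dict()                  # {'temperature': 0.3, 'top_p': 0.1, ...}
    restored = GptParamsConfig.from_dict(as_dict)
    assert restored == params
    print(params.to_json())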
    	
embodied_gen/utils/enum.py
CHANGED
@@ -102,6 +102,7 @@ class LayoutInfo(DataClassJsonMixin):
     tree: dict[str, list]
     relation: dict[str, str | list[str]]
     objs_desc: dict[str, str] = field(default_factory=dict)
+    objs_mapping: dict[str, str] = field(default_factory=dict)
     assets: dict[str, str] = field(default_factory=dict)
     quality: dict[str, str] = field(default_factory=dict)
     position: dict[str, list[float]] = field(default_factory=dict)
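Illustration (not part of the diff): LayoutInfo is a DataClassJsonMixin, so the new objs_mapping field is optional and round-trips through the layout JSON like the other dict fields. A minimal sketch; the node names and relation keys below are hypothetical, only tree and relation are required:

    from embodied_gen.utils.enum import LayoutInfo

    layout = LayoutInfo(
        tree={"table": ["mug"]},                  # hypothetical scene tree
        relation={"manipulated_objs": ["mug"]},   # hypothetical relation key
        objs_mapping={"mug": "mug_0"},            # new field, illustrative values
    )
    print(LayoutInfo.from_dict(layout.to_dict()) == layout)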
    	
embodied_gen/utils/gaussian.py
CHANGED
@@ -35,7 +35,6 @@ __all__ = [
     "set_random_seed",
     "export_splats",
     "create_splats_with_optimizers",
-    "compute_pinhole_intrinsics",
     "resize_pinhole_intrinsics",
     "restore_scene_scale_and_position",
 ]
@@ -265,12 +264,12 @@ def create_splats_with_optimizers(
     return splats, optimizers


-def compute_pinhole_intrinsics(
-    image_w: int, image_h: int, …
+def compute_intrinsics_from_fovy(
+    image_w: int, image_h: int, fovy_deg: float
 ) -> np.ndarray:
-    …
-    …
-    …
+    fovy_rad = np.deg2rad(fovy_deg)
+    fy = image_h / (2 * np.tan(fovy_rad / 2))
+    fx = fy * (image_w / image_h)
     cx = image_w / 2
     cy = image_h / 2
     K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
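Illustration (not part of the diff): the renamed helper now derives focal lengths from a vertical field of view instead of taking them directly. For the simulate_sapien.py defaults (512x512 at fovy 75 degrees), fy = 512 / (2 * tan(37.5 deg)) is about 333.6, and fx equals fy for a square image:

    import numpy as np

    from embodied_gen.utils.gaussian import compute_intrinsics_from_fovy

    K = compute_intrinsics_from_fovy(image_w=512, image_h=512, fovy_deg=75.0)
    print(np.round(K, 1))  # fx = fy ~ 333.6, cx = cy = 256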
    	
        embodied_gen/utils/geometry.py
    ADDED
    
    | 
         @@ -0,0 +1,458 @@ 
     | 
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
import random
from collections import defaultdict, deque
from functools import wraps
from typing import Literal

import numpy as np
import torch
import trimesh
from matplotlib.path import Path
from pyquaternion import Quaternion
from scipy.spatial import ConvexHull
from scipy.spatial.transform import Rotation as R
from shapely.geometry import Polygon
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.log import logger

__all__ = [
    "bfs_placement",
    "with_seed",
    "matrix_to_pose",
    "pose_to_matrix",
    "quaternion_multiply",
    "check_reachable",
    "bfs_placement",
    "compose_mesh_scene",
    "compute_pinhole_intrinsics",
]


def matrix_to_pose(matrix: np.ndarray) -> list[float]:
    """Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).

    Args:
        matrix (np.ndarray): 4x4 transformation matrix.

    Returns:
        List[float]: Pose as [x, y, z, qx, qy, qz, qw].
    """
    x, y, z = matrix[:3, 3]
    rot_mat = matrix[:3, :3]
    quat = R.from_matrix(rot_mat).as_quat()
    qx, qy, qz, qw = quat

    return [x, y, z, qx, qy, qz, qw]


def pose_to_matrix(pose: list[float]) -> np.ndarray:
    """Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.

    Args:
        List[float]: Pose as [x, y, z, qx, qy, qz, qw].

    Returns:
        matrix (np.ndarray): 4x4 transformation matrix.
    """
    x, y, z, qx, qy, qz, qw = pose
    r = R.from_quat([qx, qy, qz, qw])
    matrix = np.eye(4)
    matrix[:3, :3] = r.as_matrix()
    matrix[:3, 3] = [x, y, z]

    return matrix


def compute_xy_bbox(
    vertices: np.ndarray, col_x: int = 0, col_y: int = 2
) -> list[float]:
    x_vals = vertices[:, col_x]
    y_vals = vertices[:, col_y]
    return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()


def has_iou_conflict(
    new_box: list[float],
    placed_boxes: list[list[float]],
    iou_threshold: float = 0.0,
) -> bool:
    new_min_x, new_max_x, new_min_y, new_max_y = new_box
    for min_x, max_x, min_y, max_y in placed_boxes:
        ix1 = max(new_min_x, min_x)
        iy1 = max(new_min_y, min_y)
        ix2 = min(new_max_x, max_x)
        iy2 = min(new_max_y, max_y)
        inter_area = max(0, ix2 - ix1) * max(0, iy2 - iy1)
        if inter_area > iou_threshold:
            return True
    return False


def with_seed(seed_attr_name: str = "seed"):
    """A parameterized decorator that temporarily sets the random seed."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            seed = kwargs.get(seed_attr_name, None)
            if seed is not None:
                py_state = random.getstate()
                np_state = np.random.get_state()
                torch_state = torch.get_rng_state()

                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                try:
                    result = func(*args, **kwargs)
                finally:
                    random.setstate(py_state)
                    np.random.set_state(np_state)
                    torch.set_rng_state(torch_state)
                return result
            else:
                return func(*args, **kwargs)

        return wrapper

    return decorator


def compute_convex_hull_path(
    vertices: np.ndarray,
    z_threshold: float = 0.05,
    interp_per_edge: int = 3,
    margin: float = -0.02,
) -> Path:
    top_vertices = vertices[
        vertices[:, 1] > vertices[:, 1].max() - z_threshold
    ]
    top_xy = top_vertices[:, [0, 2]]

    if len(top_xy) < 3:
        raise ValueError("Not enough points to form a convex hull")

    hull = ConvexHull(top_xy)
    hull_points = top_xy[hull.vertices]

    polygon = Polygon(hull_points)
    polygon = polygon.buffer(margin)
    hull_points = np.array(polygon.exterior.coords)

    dense_points = []
    for i in range(len(hull_points)):
        p1 = hull_points[i]
        p2 = hull_points[(i + 1) % len(hull_points)]
        for t in np.linspace(0, 1, interp_per_edge, endpoint=False):
            pt = (1 - t) * p1 + t * p2
            dense_points.append(pt)

    return Path(np.array(dense_points), closed=True)


def find_parent_node(node: str, tree: dict) -> str | None:
    for parent, children in tree.items():
        if any(child[0] == node for child in children):
            return parent
    return None


def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
    x1, x2, y1, y2 = box
    corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]

    num_inside = sum(hull.contains_point(c) for c in corners)
    return num_inside >= threshold


def compute_axis_rotation_quat(
    axis: Literal["x", "y", "z"], angle_rad: float
) -> list[float]:
    if axis.lower() == 'x':
        q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
    elif axis.lower() == 'y':
        q = Quaternion(axis=[0, 1, 0], angle=angle_rad)
    elif axis.lower() == 'z':
        q = Quaternion(axis=[0, 0, 1], angle=angle_rad)
    else:
        raise ValueError(f"Unsupported axis '{axis}', must be one of x, y, z")

    return [q.x, q.y, q.z, q.w]


def quaternion_multiply(
    init_quat: list[float], rotate_quat: list[float]
) -> list[float]:
    qx, qy, qz, qw = init_quat
    q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
    qx, qy, qz, qw = rotate_quat
    q2 = Quaternion(w=qw, x=qx, y=qy, z=qz)
    quat = q2 * q1

    return [quat.x, quat.y, quat.z, quat.w]


def check_reachable(
    base_xyz: np.ndarray,
    reach_xyz: np.ndarray,
    min_reach: float = 0.25,
    max_reach: float = 0.85,
) -> bool:
    """Check if the target point is within the reachable range."""
    distance = np.linalg.norm(reach_xyz - base_xyz)

    return min_reach < distance < max_reach


@with_seed("seed")
def bfs_placement(
    layout_info: LayoutInfo,
    floor_margin: float = 0,
    beside_margin: float = 0.1,
    max_attempts: int = 3000,
    rotate_objs: bool = True,
    rotate_bg: bool = True,
    limit_reach_range: bool = True,
    robot_dim: float = 0.12,
    seed: int = None,
) -> LayoutInfo:
    object_mapping = layout_info.objs_mapping
    position = {}  # node: [x, y, z, qx, qy, qz, qw]
    parent_bbox_xy = {}
    placed_boxes_map = defaultdict(list)
    mesh_info = defaultdict(dict)
    robot_node = layout_info.relation[Scene3DItemEnum.ROBOT.value]
    for node in object_mapping:
        if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
            bg_quat = (
                compute_axis_rotation_quat(
                    axis="y",
                    angle_rad=np.random.uniform(0, 2 * np.pi),
                )
                if rotate_bg
                else [0, 0, 0, 1]
            )
            bg_quat = [round(q, 4) for q in bg_quat]
            continue

        mesh_path = (
            f"{layout_info.assets[node]}/mesh/{node.replace(' ', '_')}.obj"
        )
        mesh_info[node]["path"] = mesh_path
        mesh = trimesh.load(mesh_path)
        vertices = mesh.vertices
        z1 = np.percentile(vertices[:, 1], 1)
        z2 = np.percentile(vertices[:, 1], 99)

        if object_mapping[node] == Scene3DItemEnum.CONTEXT.value:
            object_quat = [0, 0, 0, 1]
            mesh_info[node]["surface"] = compute_convex_hull_path(vertices)
            # Put robot in the CONTEXT edge.
            x, y = random.choice(mesh_info[node]["surface"].vertices)
            theta = np.arctan2(y, x)
            quat_initial = Quaternion(axis=[0, 0, 1], angle=theta)
            quat_extra = Quaternion(axis=[0, 0, 1], angle=np.pi)
            quat = quat_extra * quat_initial
            _pose = [x, y, z2 - z1, quat.x, quat.y, quat.z, quat.w]
            position[robot_node] = [round(v, 4) for v in _pose]
            node_box = [
                x - robot_dim / 2,
                x + robot_dim / 2,
                y - robot_dim / 2,
                y + robot_dim / 2,
            ]
            placed_boxes_map[node].append(node_box)
        elif rotate_objs:
            # For manipulated and distractor objects, apply random rotation
            angle_rad = np.random.uniform(0, 2 * np.pi)
            object_quat = compute_axis_rotation_quat(
                axis="y", angle_rad=angle_rad
            )
            object_quat_scipy = np.roll(object_quat, 1)  # [w, x, y, z]
            rotation = R.from_quat(object_quat_scipy).as_matrix()
            vertices = np.dot(mesh.vertices, rotation.T)
            z1 = np.percentile(vertices[:, 1], 1)
            z2 = np.percentile(vertices[:, 1], 99)

        x1, x2, y1, y2 = compute_xy_bbox(vertices)
        mesh_info[node]["pose"] = [x1, x2, y1, y2, z1, z2, *object_quat]
        mesh_info[node]["area"] = max(1e-5, (x2 - x1) * (y2 - y1))

    root = list(layout_info.tree.keys())[0]
    queue = deque([((root, None), layout_info.tree.get(root, []))])
    while queue:
        (node, relation), children = queue.popleft()
        if node not in object_mapping:
            continue

        if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
            position[node] = [0, 0, floor_margin, *bg_quat]
        else:
            x1, x2, y1, y2, z1, z2, qx, qy, qz, qw = mesh_info[node]["pose"]
            if object_mapping[node] == Scene3DItemEnum.CONTEXT.value:
                position[node] = [0, 0, -round(z1, 4), qx, qy, qz, qw]
                parent_bbox_xy[node] = [x1, x2, y1, y2, z1, z2]
            elif object_mapping[node] in [
                Scene3DItemEnum.MANIPULATED_OBJS.value,
                Scene3DItemEnum.DISTRACTOR_OBJS.value,
            ]:
                parent_node = find_parent_node(node, layout_info.tree)
                parent_pos = position[parent_node]
                (
                    p_x1,
                    p_x2,
                    p_y1,
                    p_y2,
                    p_z1,
                    p_z2,
                ) = parent_bbox_xy[parent_node]

                obj_dx = x2 - x1
                obj_dy = y2 - y1
                hull_path = mesh_info[parent_node].get("surface")
                for _ in range(max_attempts):
                    node_x1 = random.uniform(p_x1, p_x2 - obj_dx)
                    node_y1 = random.uniform(p_y1, p_y2 - obj_dy)
                    node_box = [
                        node_x1,
                        node_x1 + obj_dx,
                        node_y1,
                        node_y1 + obj_dy,
                    ]
                    if hull_path and not all_corners_inside(
                        hull_path, node_box
                    ):
                        continue
                    # Make sure the manipulated object is reachable by robot.
                    if (
                        limit_reach_range
                        and object_mapping[node]
                        == Scene3DItemEnum.MANIPULATED_OBJS.value
                    ):
                        cx = parent_pos[0] + node_box[0] + obj_dx / 2
                        cy = parent_pos[1] + node_box[2] + obj_dy / 2
                        cz = parent_pos[2] + p_z2 - z1
                        robot_pose = position[robot_node][:3]
                        if not check_reachable(
                            base_xyz=np.array(robot_pose),
                            reach_xyz=np.array([cx, cy, cz]),
                        ):
                            continue

                    if not has_iou_conflict(
                        node_box, placed_boxes_map[parent_node]
                    ):
                        z_offset = 0
                        break
                else:
                    logger.warning(
                        f"Cannot place {node} on {parent_node} without overlap"
                        f" after {max_attempts} attempts, place beside {parent_node}."
                    )
                    for _ in range(max_attempts):
                        node_x1 = random.choice(
                            [
                                random.uniform(
                                    p_x1 - obj_dx - beside_margin,
                                    p_x1 - obj_dx,
                                ),
                                random.uniform(p_x2, p_x2 + beside_margin),
                            ]
                        )
                        node_y1 = random.choice(
                            [
                                random.uniform(
                                    p_y1 - obj_dy - beside_margin,
                                    p_y1 - obj_dy,
                                ),
                                random.uniform(p_y2, p_y2 + beside_margin),
                            ]
                        )
                        node_box = [
                            node_x1,
                            node_x1 + obj_dx,
                            node_y1,
                            node_y1 + obj_dy,
                        ]
                        z_offset = -(parent_pos[2] + p_z2)
                        if not has_iou_conflict(
                            node_box, placed_boxes_map[parent_node]
                        ):
                            break

                placed_boxes_map[parent_node].append(node_box)

                abs_cx = parent_pos[0] + node_box[0] + obj_dx / 2
                abs_cy = parent_pos[1] + node_box[2] + obj_dy / 2
                abs_cz = parent_pos[2] + p_z2 - z1 + z_offset
                position[node] = [
                    round(v, 4)
                    for v in [abs_cx, abs_cy, abs_cz, qx, qy, qz, qw]
                ]
                parent_bbox_xy[node] = [x1, x2, y1, y2, z1, z2]

        sorted_children = sorted(
            children, key=lambda x: -mesh_info[x[0]].get("area", 0)
        )
        for child, rel in sorted_children:
            queue.append(((child, rel), layout_info.tree.get(child, [])))

    layout_info.position = position

    return layout_info


def compose_mesh_scene(
    layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
) -> None:
    object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
    scene = trimesh.Scene()
    for node in layout_info.assets:
        if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
            mesh_path = f"{layout_info.assets[node]}/mesh_model.ply"
            if not with_bg:
                continue
        else:
            mesh_path = (
                f"{layout_info.assets[node]}/mesh/{node.replace(' ', '_')}.obj"
            )

        mesh = trimesh.load(mesh_path)
        offset = np.array(layout_info.position[node])[[0, 2, 1]]
        mesh.vertices += offset
        scene.add_geometry(mesh, node_name=node)

    os.makedirs(os.path.dirname(out_scene_path), exist_ok=True)
    scene.export(out_scene_path)
    logger.info(f"Composed interactive 3D layout saved in {out_scene_path}")

    return


def compute_pinhole_intrinsics(
    image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
    fov_rad = np.deg2rad(fov_deg)
    fx = image_w / (2 * np.tan(fov_rad / 2))
    fy = fx  # assuming square pixels
    cx = image_w / 2
    cy = image_h / 2
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])

    return K
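A minimal usage sketch for the pose helpers added above (illustrative only; it assumes the embodied_gen package from this commit is importable and that numpy, scipy, and pyquaternion are installed):

import numpy as np

from embodied_gen.utils.geometry import (
    check_reachable,
    matrix_to_pose,
    pose_to_matrix,
    quaternion_multiply,
)

# Round-trip a pose through the 4x4 homogeneous transform helpers.
pose = [0.3, 0.1, 0.5, 0.0, 0.0, 0.7071, 0.7071]  # x, y, z, qx, qy, qz, qw
T = pose_to_matrix(pose)
print(matrix_to_pose(T))  # same pose, up to quaternion sign

# Stack an extra rotation about z on top of the initial orientation.
print(quaternion_multiply(pose[3:], [0.0, 0.0, 0.7071, 0.7071]))

# Reachability gate used by bfs_placement: 0.25 m < distance < 0.85 m.
print(check_reachable(np.array([0.0, 0.0, 0.0]), np.array([0.3, 0.1, 0.5])))  # True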
    	
        embodied_gen/utils/monkey_patches.py
    CHANGED
    
    | 
         @@ -18,6 +18,7 @@ import os 
     | 
|
| 18 | 
         
             
            import sys
         
     | 
| 19 | 
         
             
            import zipfile
         
     | 
| 20 | 
         | 
| 
         | 
|
| 21 | 
         
             
            import torch
         
     | 
| 22 | 
         
             
            from huggingface_hub import hf_hub_download
         
     | 
| 23 | 
         
             
            from omegaconf import OmegaConf
         
     | 
| 
         @@ -150,3 +151,68 @@ def monkey_patch_pano2room(): 
     | 
|
| 150 | 
         
             
                    self.inpaint_pipe = pipe
         
     | 
| 151 | 
         | 
| 152 | 
         
             
                SDFTInpainter.__init__ = patched_sd_inpaint_init
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 18 | 
         
             
            import sys
         
     | 
| 19 | 
         
             
            import zipfile
         
     | 
| 20 | 
         | 
| 21 | 
         
            +
            import numpy as np
         
     | 
| 22 | 
         
             
            import torch
         
     | 
| 23 | 
         
             
            from huggingface_hub import hf_hub_download
         
     | 
| 24 | 
         
             
            from omegaconf import OmegaConf
         
     | 
| 
         | 
|
| 151 | 
         
             
                    self.inpaint_pipe = pipe
         
     | 
| 152 | 
         SDFTInpainter.__init__ = patched_sd_inpaint_init
+
+
+def monkey_patch_maniskill():
+    from mani_skill.envs.scene import ManiSkillScene
+
+    def get_sensor_images(
+        self, obs: dict[str, any]
+    ) -> dict[str, dict[str, torch.Tensor]]:
+        sensor_data = dict()
+        for name, sensor in self.sensors.items():
+            sensor_data[name] = sensor.get_images(obs[name])
+        return sensor_data
+
+    def get_human_render_camera_images(
+        self, camera_name: str = None, return_alpha: bool = False
+    ) -> dict[str, torch.Tensor]:
+        def get_rgba_tensor(camera, return_alpha):
+            color = camera.get_obs(
+                rgb=True, depth=False, segmentation=False, position=False
+            )["rgb"]
+            if return_alpha:
+                seg_labels = camera.get_obs(
+                    rgb=False, depth=False, segmentation=True, position=False
+                )["segmentation"]
+                masks = np.where((seg_labels.cpu() > 0), 255, 0).astype(
+                    np.uint8
+                )
+                masks = torch.tensor(masks).to(color.device)
+                color = torch.concat([color, masks], dim=-1)
+
+            return color
+
+        image_data = dict()
+        if self.gpu_sim_enabled:
+            if self.parallel_in_single_scene:
+                for name, camera in self.human_render_cameras.items():
+                    camera.camera._render_cameras[0].take_picture()
+                    rgba = get_rgba_tensor(camera, return_alpha)
+                    image_data[name] = rgba
+            else:
+                for name, camera in self.human_render_cameras.items():
+                    if camera_name is not None and name != camera_name:
+                        continue
+                    assert camera.config.shader_config.shader_pack not in [
+                        "rt",
+                        "rt-fast",
+                        "rt-med",
+                    ], "ray tracing shaders do not work with parallel rendering"
+                    camera.capture()
+                    rgba = get_rgba_tensor(camera, return_alpha)
+                    image_data[name] = rgba
+        else:
+            for name, camera in self.human_render_cameras.items():
+                if camera_name is not None and name != camera_name:
+                    continue
+                camera.capture()
+                rgba = get_rgba_tensor(camera, return_alpha)
+                image_data[name] = rgba
+
+        return image_data
+
+    ManiSkillScene.get_sensor_images = get_sensor_images
+    ManiSkillScene.get_human_render_camera_images = (
+        get_human_render_camera_images
+    )
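Because the patch rebinds `get_sensor_images` and `get_human_render_camera_images` directly on the `ManiSkillScene` class, applying it once at startup is enough for every scene created afterwards. A minimal usage sketch:

    from mani_skill.envs.scene import ManiSkillScene
    from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

    monkey_patch_maniskill()  # rebind the render helpers on ManiSkillScene
    assert hasattr(ManiSkillScene, "get_human_render_camera_images")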
    	
        embodied_gen/utils/process_media.py
    CHANGED
    
@@ -166,7 +166,7 @@ def combine_images_to_grid(
     images: list[str | Image.Image],
     cat_row_col: tuple[int, int] = None,
     target_wh: tuple[int, int] = (512, 512),
-) -> list[
+) -> list[Image.Image]:
     n_images = len(images)
     if n_images == 1:
         return images
@@ -377,6 +377,42 @@ def parse_text_prompts(prompts: list[str]) -> list[str]:
     return prompts
 
 
+def alpha_blend_rgba(
+    fg_image: Union[str, Image.Image, np.ndarray],
+    bg_image: Union[str, Image.Image, np.ndarray],
+) -> Image.Image:
+    """Alpha blends a foreground RGBA image over a background RGBA image.
+
+    Args:
+        fg_image: Foreground image. Can be a file path (str), a PIL Image,
+            or a NumPy ndarray.
+        bg_image: Background image. Can be a file path (str), a PIL Image,
+            or a NumPy ndarray.
+
+    Returns:
+        A PIL Image representing the alpha-blended result in RGBA mode.
+    """
+    if isinstance(fg_image, str):
+        fg_image = Image.open(fg_image)
+    elif isinstance(fg_image, np.ndarray):
+        fg_image = Image.fromarray(fg_image)
+
+    if isinstance(bg_image, str):
+        bg_image = Image.open(bg_image)
+    elif isinstance(bg_image, np.ndarray):
+        bg_image = Image.fromarray(bg_image)
+
+    if fg_image.size != bg_image.size:
+        raise ValueError(
+            f"Image sizes not match {fg_image.size} v.s. {bg_image.size}."
+        )
+
+    fg = fg_image.convert("RGBA")
+    bg = bg_image.convert("RGBA")
+
+    return Image.alpha_composite(bg, fg)
+
+
 def check_object_edge_truncated(
     mask: np.ndarray, edge_threshold: int = 5
 ) -> bool:
@@ -400,8 +436,15 @@ def check_object_edge_truncated(
 
 
 if __name__ == "__main__":
-
-        "outputs/
-        "outputs/
-        "
-
+    image_paths = [
+        "outputs/layouts_sim/task_0000/images/pen.png",
+        "outputs/layouts_sim/task_0000/images/notebook.png",
+        "outputs/layouts_sim/task_0000/images/mug.png",
+        "outputs/layouts_sim/task_0000/images/lamp.png",
+        "outputs/layouts_sim2/task_0014/images/cloth.png",  # TODO
+    ]
+    for image_path in image_paths:
+        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+        mask = image[..., -1]
+        flag = check_object_edge_truncated(mask)
+        print(flag, image_path)
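For reference, a minimal sketch of the new `alpha_blend_rgba` helper (the file paths below are placeholders); both inputs must already share the same resolution:

    from embodied_gen.utils.process_media import alpha_blend_rgba

    blended = alpha_blend_rgba("fg_rgba.png", "bg_rgba.png")  # placeholder paths
    blended.save("blended_rgba.png")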
    	
        embodied_gen/utils/simulation.py
    ADDED
    
@@ -0,0 +1,633 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import json
+import logging
+import os
+import xml.etree.ElementTree as ET
+from collections import defaultdict
+from typing import Literal
+
+import mplib
+import numpy as np
+import sapien.core as sapien
+import sapien.physx as physx
+import torch
+from mani_skill.agents.base_agent import BaseAgent
+from mani_skill.envs.scene import ManiSkillScene
+from mani_skill.examples.motionplanning.panda.utils import (
+    compute_grasp_info_by_obb,
+)
+from mani_skill.utils.geometry.trimesh_utils import get_component_mesh
+from PIL import Image, ImageColor
+from scipy.spatial.transform import Rotation as R
+from embodied_gen.data.utils import DiffrastRender
+from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
+from embodied_gen.utils.geometry import quaternion_multiply
+from embodied_gen.utils.log import logger
+
+COLORMAP = list(set(ImageColor.colormap.values()))
+COLOR_PALETTE = np.array(
+    [ImageColor.getrgb(c) for c in COLORMAP], dtype=np.uint8
+)
+SIM_COORD_ALIGN = np.array(
+    [
+        [1.0, 0.0, 0.0, 0.0],
+        [0.0, -1.0, 0.0, 0.0],
+        [0.0, 0.0, -1.0, 0.0],
+        [0.0, 0.0, 0.0, 1.0],
+    ]
+)  # Used to align SAPIEN, MuJoCo coordinate system with the world coordinate system
+
+__all__ = [
+    "SIM_COORD_ALIGN",
+    "FrankaPandaGrasper",
+    "load_assets_from_layout_file",
+    "load_mani_skill_robot",
+    "render_images",
+]
+
+
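`SIM_COORD_ALIGN` simply negates the Y and Z axes of a homogeneous point or transform; a tiny illustration with made-up values:

    import numpy as np
    from embodied_gen.utils.simulation import SIM_COORD_ALIGN

    point_h = np.array([0.1, 0.2, 0.3, 1.0])  # made-up homogeneous point
    print(SIM_COORD_ALIGN @ point_h)          # [ 0.1 -0.2 -0.3  1. ]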
+def load_actor_from_urdf(
+    scene: ManiSkillScene | sapien.Scene,
+    file_path: str,
+    pose: sapien.Pose,
+    env_idx: int = None,
+    use_static: bool = False,
+    update_mass: bool = False,
+) -> sapien.pysapien.Entity:
+    tree = ET.parse(file_path)
+    root = tree.getroot()
+    node_name = root.get("name")
+    file_dir = os.path.dirname(file_path)
+    visual_file = root.find('.//visual/geometry/mesh').get("filename")
+    collision_file = root.find('.//collision/geometry/mesh').get("filename")
+    visual_file = os.path.join(file_dir, visual_file)
+    collision_file = os.path.join(file_dir, collision_file)
+    static_fric = root.find('.//collision/gazebo/mu1').text
+    dynamic_fric = root.find('.//collision/gazebo/mu2').text
+
+    material = physx.PhysxMaterial(
+        static_friction=np.clip(float(static_fric), 0.1, 0.7),
+        dynamic_friction=np.clip(float(dynamic_fric), 0.1, 0.6),
+        restitution=0.05,
+    )
+    builder = scene.create_actor_builder()
+
+    body_type = "static" if use_static else "dynamic"
+    builder.set_physx_body_type(body_type)
+    builder.add_multiple_convex_collisions_from_file(
+        collision_file if body_type == "dynamic" else visual_file,
+        material=material,
+        # decomposition="coacd",
+        # decomposition_params=dict(
+        #     threshold=0.05, max_convex_hull=64, verbose=False
+        # ),
+    )
+
+    builder.add_visual_from_file(visual_file)
+    builder.set_initial_pose(pose)
+    if isinstance(scene, ManiSkillScene) and env_idx is not None:
+        builder.set_scene_idxs([env_idx])
+
+    actor = builder.build(name=f"{node_name}-{env_idx}")
+
+    if update_mass and hasattr(actor.components[1], "mass"):
+        node_mass = float(root.find('.//inertial/mass').get("value"))
+        actor.components[1].set_mass(node_mass)
+
+    return actor
+
+
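A minimal sketch of loading one asset directly with `load_actor_from_urdf`; the URDF path is a placeholder, and the file is assumed to contain the visual/collision mesh references and gazebo `mu1`/`mu2` friction tags that the parser reads:

    import sapien.core as sapien
    from embodied_gen.utils.simulation import load_actor_from_urdf

    scene = sapien.Scene()
    scene.set_timestep(1 / 100)
    actor = load_actor_from_urdf(
        scene,
        "outputs/sample/sample.urdf",                    # placeholder path
        sapien.Pose(p=[0.0, 0.0, 0.1], q=[1, 0, 0, 0]),  # q is (w, x, y, z)
    )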
+def load_assets_from_layout_file(
+    scene: ManiSkillScene | sapien.Scene,
+    layout: LayoutInfo | str,
+    z_offset: float = 0.0,
+    init_quat: list[float] = [0, 0, 0, 1],
+    env_idx: int = None,
+) -> dict[str, sapien.pysapien.Entity]:
+    """Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
+
+    Args:
+        scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
+        layout (LayoutInfo): The layout information data.
+        z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
+        init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
+        env_idx (int): Environment index for multi-environment setup.
+    """
+    if isinstance(layout, str) and layout.endswith(".json"):
+        layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
+
+    actors = dict()
+    for node in layout.assets:
+        file_dir = layout.assets[node]
+        file_name = f"{node.replace(' ', '_')}.urdf"
+        urdf_file = os.path.join(file_dir, file_name)
+
+        if layout.objs_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
+            continue
+
+        position = layout.position[node].copy()
+        if layout.objs_mapping[node] != Scene3DItemEnum.CONTEXT.value:
+            position[2] += z_offset
+
+        use_static = (
+            layout.relation.get(Scene3DItemEnum.CONTEXT.value, None) == node
+        )
+
+        # Combine initial quaternion with object quaternion
+        x, y, z, qx, qy, qz, qw = position
+        qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], init_quat)
+        actor = load_actor_from_urdf(
+            scene,
+            urdf_file,
+            sapien.Pose(p=[x, y, z], q=[qw, qx, qy, qz]),
+            env_idx,
+            use_static=use_static,
+            update_mass=False,
+        )
+        actors[node] = actor
+
+    return actors
+
+
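A minimal sketch of populating a plain SAPIEN scene from a generated layout (the JSON path is a placeholder):

    import sapien.core as sapien
    from embodied_gen.utils.simulation import load_assets_from_layout_file

    scene = sapien.Scene()
    scene.set_timestep(1 / 100)
    actors = load_assets_from_layout_file(
        scene, "outputs/layouts_sim/task_0000/layout.json"  # placeholder path
    )
    print(list(actors.keys()))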
+def load_mani_skill_robot(
+    scene: sapien.Scene | ManiSkillScene,
+    layout: LayoutInfo | str,
+    control_freq: int = 20,
+    robot_init_qpos_noise: float = 0.0,
+    control_mode: str = "pd_joint_pos",
+    backend_str: tuple[str, str] = ("cpu", "gpu"),
+) -> BaseAgent:
+    from mani_skill.agents import REGISTERED_AGENTS
+    from mani_skill.envs.scene import ManiSkillScene
+    from mani_skill.envs.utils.system.backend import (
+        parse_sim_and_render_backend,
+    )
+
+    if isinstance(layout, str) and layout.endswith(".json"):
+        layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
+
+    robot_name = layout.relation[Scene3DItemEnum.ROBOT.value]
+    x, y, z, qx, qy, qz, qw = layout.position[robot_name]
+    delta_z = 0.002  # Add small offset to avoid collision.
+    pose = sapien.Pose([x, y, z + delta_z], [qw, qx, qy, qz])
+
+    if robot_name not in REGISTERED_AGENTS:
+        logger.warning(
+            f"Robot `{robot_name}` not registered, chosen from {REGISTERED_AGENTS.keys()}, use `panda` instead."
+        )
+        robot_name = "panda"
+
+    ROBOT_CLS = REGISTERED_AGENTS[robot_name].agent_cls
+    backend = parse_sim_and_render_backend(*backend_str)
+    if isinstance(scene, sapien.Scene):
+        scene = ManiSkillScene([scene], device=backend_str[0], backend=backend)
+    robot = ROBOT_CLS(
+        scene=scene,
+        control_freq=control_freq,
+        control_mode=control_mode,
+        initial_pose=pose,
+    )
+
+    # Set initial joint positions in radians (joint0 to joint6 plus 2 gripper fingers).
+    qpos = np.array(
+        [
+            0.0,
+            np.pi / 8,
+            0,
+            -np.pi * 3 / 8,
+            0,
+            np.pi * 3 / 4,
+            np.pi / 4,
+            0.04,
+            0.04,
+        ]
+    )
+    qpos = (
+        np.random.normal(
+            0, robot_init_qpos_noise, (len(scene.sub_scenes), len(qpos))
+        )
+        + qpos
+    )
+    qpos[:, -2:] = 0.04
+    robot.reset(qpos)
+    robot.init_qpos = robot.robot.qpos
+    robot.controller.controllers["gripper"].reset()
+
+    return robot
+
+
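A matching sketch for instantiating the robot named in the same layout (reusing the scene and placeholder path from the sketch above):

    from embodied_gen.utils.simulation import load_mani_skill_robot

    robot = load_mani_skill_robot(scene, "outputs/layouts_sim/task_0000/layout.json")
    print(robot.robot.qpos)  # initial joint positions, gripper opened to 0.04 m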
+def render_images(
+    camera: sapien.render.RenderCameraComponent,
+    render_keys: list[
+        Literal[
+            "Color",
+            "Segmentation",
+            "Normal",
+            "Mask",
+            "Depth",
+            "Foreground",
+        ]
+    ] = None,
+) -> dict[str, Image.Image]:
+    """Render images from a given sapien camera.
+
+    Args:
+        camera (sapien.render.RenderCameraComponent): The camera to render from.
+        render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
+
+    Returns:
+        Dict[str, Image.Image]: Dictionary of rendered images.
+    """
+    if render_keys is None:
+        render_keys = [
+            "Color",
+            "Segmentation",
+            "Normal",
+            "Mask",
+            "Depth",
+            "Foreground",
+        ]
+
+    results: dict[str, Image.Image] = {}
+    if "Color" in render_keys:
+        color = camera.get_picture("Color")
+        color_rgb = (np.clip(color[..., :3], 0, 1) * 255).astype(np.uint8)
+        results["Color"] = Image.fromarray(color_rgb)
+
+    if "Mask" in render_keys:
+        alpha = (np.clip(color[..., 3], 0, 1) * 255).astype(np.uint8)
+        results["Mask"] = Image.fromarray(alpha)
+
+    if "Segmentation" in render_keys:
+        seg_labels = camera.get_picture("Segmentation")
+        label0 = seg_labels[..., 0].astype(np.uint8)
+        seg_color = COLOR_PALETTE[label0]
+        results["Segmentation"] = Image.fromarray(seg_color)
+
+    if "Foreground" in render_keys:
+        seg_labels = camera.get_picture("Segmentation")
+        label0 = seg_labels[..., 0]
+        mask = np.where((label0 > 1), 255, 0).astype(np.uint8)
+        color = camera.get_picture("Color")
+        color_rgb = (np.clip(color[..., :3], 0, 1) * 255).astype(np.uint8)
+        foreground = np.concatenate([color_rgb, mask[..., None]], axis=-1)
+        results["Foreground"] = Image.fromarray(foreground)
+
+    if "Normal" in render_keys:
+        normal = camera.get_picture("Normal")[..., :3]
+        normal_img = (((normal + 1) / 2) * 255).astype(np.uint8)
+        results["Normal"] = Image.fromarray(normal_img)
+
+    if "Depth" in render_keys:
+        position_map = camera.get_picture("Position")
+        depth = -position_map[..., 2]
+        alpha = torch.tensor(color[..., 3], dtype=torch.float32)
+        norm_depth = DiffrastRender.normalize_map_by_mask(
+            torch.tensor(depth), alpha
+        )
+        depth_img = (norm_depth * 255).to(torch.uint8).numpy()
+        results["Depth"] = Image.fromarray(depth_img)
+
+    return results
+
+
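A short sketch of the intended call pattern for an already-mounted camera; note that the depth branch reuses the alpha channel produced by the color pass, so `Color` should be requested alongside `Depth`:

    scene.update_render()
    camera.take_picture()
    images = render_images(camera, render_keys=["Color", "Depth"])
    images["Color"].save("color.png")  # placeholder output path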
| 309 | 
         
            +
            class SapienSceneManager:
         
     | 
| 310 | 
         
            +
                """A class to manage SAPIEN simulator."""
         
     | 
| 311 | 
         
            +
             
     | 
| 312 | 
         
            +
                def __init__(
         
     | 
| 313 | 
         
            +
                    self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
         
     | 
| 314 | 
         
            +
                ) -> None:
         
     | 
| 315 | 
         
            +
                    self.sim_freq = sim_freq
         
     | 
| 316 | 
         
            +
                    self.ray_tracing = ray_tracing
         
     | 
| 317 | 
         
            +
                    self.device = device
         
     | 
| 318 | 
         
            +
                    self.renderer = sapien.SapienRenderer()
         
     | 
| 319 | 
         
            +
                    self.scene = self._setup_scene()
         
     | 
| 320 | 
         
            +
                    self.cameras: list[sapien.render.RenderCameraComponent] = []
         
     | 
| 321 | 
         
            +
                    self.actors: dict[str, sapien.pysapien.Entity] = {}
         
     | 
| 322 | 
         
            +
             
     | 
| 323 | 
         
            +
                def _setup_scene(self) -> sapien.Scene:
         
     | 
| 324 | 
         
            +
                    """Set up the SAPIEN scene with lighting and ground."""
         
     | 
| 325 | 
         
            +
                    # Ray tracing settings
         
            if self.ray_tracing:
                sapien.render.set_camera_shader_dir("rt")
                sapien.render.set_ray_tracing_samples_per_pixel(64)
                sapien.render.set_ray_tracing_path_depth(10)
                sapien.render.set_ray_tracing_denoiser("oidn")

        scene = sapien.Scene()
        scene.set_timestep(1 / self.sim_freq)

        # Add lighting
        scene.set_ambient_light([0.2, 0.2, 0.2])
        scene.add_directional_light(
            direction=[0, 1, -1],
            color=[1.5, 1.45, 1.4],
            shadow=True,
            shadow_map_size=2048,
        )
        scene.add_directional_light(
            direction=[0, -0.5, 1], color=[0.8, 0.8, 0.85], shadow=False
        )
        scene.add_directional_light(
            direction=[0, -1, 1], color=[1.0, 1.0, 1.0], shadow=False
        )

        ground_material = self.renderer.create_material()
        ground_material.base_color = [0.5, 0.5, 0.5, 1]  # rgba, gray
        ground_material.roughness = 0.7
        ground_material.metallic = 0.0
        scene.add_ground(0, render_material=ground_material)

        return scene

    def step_action(
        self,
        agent: BaseAgent,
        action: torch.Tensor,
        cameras: list[sapien.render.RenderCameraComponent],
        render_keys: list[str],
        sim_steps_per_control: int = 1,
    ) -> dict:
        agent.set_action(action)
        frames = defaultdict(list)
        for _ in range(sim_steps_per_control):
            self.scene.step()

        self.scene.update_render()
        for camera in cameras:
            camera.take_picture()
            images = render_images(camera, render_keys=render_keys)
            frames[camera.name].append(images)

        return frames
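The physics rate set by `scene.set_timestep(1 / self.sim_freq)` and the `sim_steps_per_control` argument of `step_action` are two views of the same clock; a quick, hedged sketch of the relation (the 240 Hz / 20 Hz rates are assumed values, not taken from this commit):

    sim_freq, control_freq = 240, 20              # assumed rates, not from the commit
    sim_steps_per_control = sim_freq // control_freq
    print(sim_steps_per_control)                  # 12 physics steps per control action
    print(round(1 / sim_freq, 5))                 # 0.00417 s per scene.step()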
    def create_camera(
        self,
        cam_name: str,
        pose: sapien.Pose,
        image_hw: tuple[int, int],
        fovy_deg: float,
    ) -> sapien.render.RenderCameraComponent:
        """Create a single camera in the scene.

        Args:
            cam_name (str): Name of the camera.
            pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
            image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
            fovy_deg (float): Field of view in degrees for cameras.

        Returns:
            sapien.render.RenderCameraComponent: The created camera.
        """
        cam_actor = self.scene.create_actor_builder().build_kinematic()
        cam_actor.set_pose(pose)
        camera = self.scene.add_mounted_camera(
            name=cam_name,
            mount=cam_actor,
            pose=sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0]),
            width=image_hw[1],
            height=image_hw[0],
            fovy=np.deg2rad(fovy_deg),
            near=0.01,
            far=100,
        )
        self.cameras.append(camera)

        return camera

    def initialize_circular_cameras(
        self,
        num_cameras: int,
        radius: float,
        height: float,
        target_pt: list[float],
        image_hw: tuple[int, int],
        fovy_deg: float,
    ) -> list[sapien.render.RenderCameraComponent]:
        """Initialize multiple cameras arranged in a circle.

        Args:
            num_cameras (int): Number of cameras to create.
            radius (float): Radius of the camera circle.
            height (float): Fixed Z-coordinate of the cameras.
            target_pt (list[float]): 3D point (x, y, z) that cameras look at.
            image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
            fovy_deg (float): Field of view in degrees for cameras.

        Returns:
            List[sapien.render.RenderCameraComponent]: List of created cameras.
        """
        angle_step = 2 * np.pi / num_cameras
        world_up_vec = np.array([0.0, 0.0, 1.0])
        target_pt = np.array(target_pt)

        for i in range(num_cameras):
            angle = i * angle_step
            cam_x = radius * np.cos(angle)
            cam_y = radius * np.sin(angle)
            cam_z = height
            eye_pos = [cam_x, cam_y, cam_z]

            forward_vec = target_pt - eye_pos
            forward_vec = forward_vec / np.linalg.norm(forward_vec)
            temp_right_vec = np.cross(forward_vec, world_up_vec)

            if np.linalg.norm(temp_right_vec) < 1e-6:
                temp_right_vec = np.array([1.0, 0.0, 0.0])
                if np.abs(np.dot(temp_right_vec, forward_vec)) > 0.99:
                    temp_right_vec = np.array([0.0, 1.0, 0.0])

            right_vec = temp_right_vec / np.linalg.norm(temp_right_vec)
            up_vec = np.cross(right_vec, forward_vec)
            rotation_matrix = np.array([forward_vec, -right_vec, up_vec]).T

            rot = R.from_matrix(rotation_matrix)
            scipy_quat = rot.as_quat()  # (x, y, z, w)
            quat = [
                scipy_quat[3],
                scipy_quat[0],
                scipy_quat[1],
                scipy_quat[2],
            ]  # (w, x, y, z)

            self.create_camera(
                f"camera_{i}",
                sapien.Pose(p=eye_pos, q=quat),
                image_hw,
                fovy_deg,
            )

        return self.cameras
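To sanity-check the look-at construction above outside the simulator, here is a minimal, self-contained sketch for a single camera (the eye and target positions are made-up values). The rotation matrix columns are (forward, -right, up), i.e. a camera frame with x forward, y left and z up, and the scipy quaternion is reordered from (x, y, z, w) into the (w, x, y, z) order used for SAPIEN poses:

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    eye = np.array([1.0, 0.0, 0.6])        # assumed camera position
    target = np.array([0.0, 0.0, 0.2])     # assumed look-at point
    forward = target - eye
    forward /= np.linalg.norm(forward)
    right = np.cross(forward, [0.0, 0.0, 1.0])
    right /= np.linalg.norm(right)
    up = np.cross(right, forward)
    quat_xyzw = R.from_matrix(np.array([forward, -right, up]).T).as_quat()
    quat_wxyz = [quat_xyzw[3], *quat_xyzw[:3]]  # e.g. sapien.Pose(p=eye, q=quat_wxyz)
    print(quat_wxyz)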
class FrankaPandaGrasper(object):
    def __init__(
        self,
        agent: BaseAgent,
        control_freq: float,
        joint_vel_limits: float = 2.0,
        joint_acc_limits: float = 1.0,
        finger_length: float = 0.025,
    ) -> None:
        self.agent = agent
        self.robot = agent.robot
        self.control_freq = control_freq
        self.control_timestep = 1 / control_freq
        self.joint_vel_limits = joint_vel_limits
        self.joint_acc_limits = joint_acc_limits
        self.finger_length = finger_length
        self.planners = self._setup_planner()

    def _setup_planner(self) -> mplib.Planner:
        planners = []
        for pose in self.robot.pose:
            link_names = [link.get_name() for link in self.robot.get_links()]
            joint_names = [
                joint.get_name() for joint in self.robot.get_active_joints()
            ]
            planner = mplib.Planner(
                urdf=self.agent.urdf_path,
                srdf=self.agent.urdf_path.replace(".urdf", ".srdf"),
                user_link_names=link_names,
                user_joint_names=joint_names,
                move_group="panda_hand_tcp",
                joint_vel_limits=np.ones(7) * self.joint_vel_limits,
                joint_acc_limits=np.ones(7) * self.joint_acc_limits,
            )
            planner.set_base_pose(pose.raw_pose[0].tolist())
            planners.append(planner)

        return planners

    def control_gripper(
        self,
        gripper_state: Literal[-1, 1],
        n_step: int = 10,
    ) -> np.ndarray:
        qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
        actions = []
        for _ in range(n_step):
            action = np.hstack([qpos, gripper_state])[None, ...]
            actions.append(action)

        return np.concatenate(actions, axis=0)

    def move_to_pose(
        self,
        pose: sapien.Pose,
        control_timestep: float,
        gripper_state: Literal[-1, 1],
        use_point_cloud: bool = False,
        n_max_step: int = 100,
        action_key: str = "position",
        env_idx: int = 0,
    ) -> np.ndarray:
        result = self.planners[env_idx].plan_qpos_to_pose(
            np.concatenate([pose.p, pose.q]),
            self.robot.get_qpos().cpu().numpy()[0],
            time_step=control_timestep,
            use_point_cloud=use_point_cloud,
        )

        if result["status"] != "Success":
            result = self.planners[env_idx].plan_screw(
                np.concatenate([pose.p, pose.q]),
                self.robot.get_qpos().cpu().numpy()[0],
                time_step=control_timestep,
                use_point_cloud=use_point_cloud,
            )

        if result["status"] != "Success":
            return

        sample_ratio = (len(result[action_key]) // n_max_step) + 1
        result[action_key] = result[action_key][::sample_ratio]

        n_step = len(result[action_key])
        actions = []
        for i in range(n_step):
            qpos = result[action_key][i]
            action = np.hstack([qpos, gripper_state])[None, ...]
            actions.append(action)

        return np.concatenate(actions, axis=0)
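The waypoint cap in `move_to_pose` is easy to misread, so here is the downsampling arithmetic on a concrete (assumed) plan length:

    plan_len, n_max_step = 350, 100               # 350 waypoints is an assumed value
    sample_ratio = (plan_len // n_max_step) + 1   # 350 // 100 + 1 = 4
    kept = len(range(0, plan_len, sample_ratio))  # every 4th waypoint is kept -> 88 actions
    print(sample_ratio, kept)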
    def compute_grasp_action(
        self,
        actor: sapien.pysapien.Entity,
        reach_target_only: bool = True,
        offset: tuple[float, float, float] = [0, 0, -0.05],
        env_idx: int = 0,
    ) -> np.ndarray:
        physx_rigid = actor.components[1]
        mesh = get_component_mesh(physx_rigid, to_world_frame=True)
        obb = mesh.bounding_box_oriented
        approaching = np.array([0, 0, -1])
        tcp_pose = self.agent.tcp.pose[env_idx]
        target_closing = (
            tcp_pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
        )
        grasp_info = compute_grasp_info_by_obb(
            obb,
            approaching=approaching,
            target_closing=target_closing,
            depth=self.finger_length,
        )

        closing, center = grasp_info["closing"], grasp_info["center"]
        raw_tcp_pose = tcp_pose.sp
        grasp_pose = self.agent.build_grasp_pose(approaching, closing, center)
        reach_pose = grasp_pose * sapien.Pose(p=offset)
        grasp_pose = grasp_pose * sapien.Pose(p=[0, 0, 0.01])
        actions = []
        reach_actions = self.move_to_pose(
            reach_pose,
            self.control_timestep,
            gripper_state=1,
            env_idx=env_idx,
        )
        actions.append(reach_actions)

        if reach_actions is None:
            logger.warning(
                f"Failed to reach the grasp pose for node `{actor.name}`, skipping grasping."
            )
            return None

        if not reach_target_only:
            grasp_actions = self.move_to_pose(
                grasp_pose,
                self.control_timestep,
                gripper_state=1,
                env_idx=env_idx,
            )
            actions.append(grasp_actions)
            close_actions = self.control_gripper(
                gripper_state=-1,
                env_idx=env_idx,
            )
            actions.append(close_actions)
            back_actions = self.move_to_pose(
                raw_tcp_pose,
                self.control_timestep,
                gripper_state=-1,
                env_idx=env_idx,
            )
            actions.append(back_actions)

        return np.concatenate(actions, axis=0)
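A hedged sketch of how the pieces above might be wired together: `FrankaPandaGrasper` plans a reach trajectory and the simulator replays it through `step_action`, collecting rendered frames. The `sim` object (assumed to expose `step_action`, `sim_freq` and the cameras created earlier), the "Color" render key, the per-action tensor shape, and the assumption that `FrankaPandaGrasper` is importable from `embodied_gen.utils.simulation` are all illustrative, not part of this commit:

    import torch

    def rollout_reach(sim, agent, actor, cameras, control_freq=20):
        # Plan a reach trajectory toward `actor` and execute it step by step.
        grasper = FrankaPandaGrasper(agent, control_freq=control_freq)
        actions = grasper.compute_grasp_action(actor)  # reach_target_only=True by default
        if actions is None:
            return None  # planning failed, nothing to execute
        frames = []
        for action in actions:
            frames.append(
                sim.step_action(
                    agent,
                    torch.as_tensor(action)[None, :],
                    cameras=cameras,
                    render_keys=["Color"],
                    sim_steps_per_control=sim.sim_freq // control_freq,
                )
            )
        return frames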
    	
embodied_gen/utils/tags.py
CHANGED

@@ -1 +1 @@
-VERSION = "v0.1.
+VERSION = "v0.1.3"
    	
embodied_gen/validators/quality_checkers.py
CHANGED

@@ -109,7 +109,7 @@ class MeshGeoChecker(BaseChecker):
         if self.prompt is None:
             self.prompt = """
             You are an expert in evaluating the geometry quality of generated 3D asset.
-            You will be given rendered views of a generated 3D asset with black background.
+            You will be given rendered views of a generated 3D asset, type {}, with black background.
             Your task is to evaluate the quality of the 3D asset generation,
             including geometry, structure, and appearance, based on the rendered views.
             Criteria:
@@ -130,10 +130,13 @@ class MeshGeoChecker(BaseChecker):
             Image shows a chair with simplified back legs and soft edges → YES
             """

-    def query(
+    def query(
+        self, image_paths: list[str | Image.Image], text: str = "unknown"
+    ) -> str:
+        input_prompt = self.prompt.format(text)

         return self.gpt_client.query(
-            text_prompt=
+            text_prompt=input_prompt,
             image_base64=image_paths,
         )
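The practical effect of the new `text` argument is that the asset category is spliced into the checker prompt before it reaches the GPT client. A minimal, runnable illustration with an abbreviated template (the real prompt is much longer):

    template = (
        "You are an expert in evaluating the geometry quality of generated 3D asset.\n"
        "You will be given rendered views of a generated 3D asset, type {}, with black background."
    )
    print(template.format("cup"))  # ", type cup," now appears in the second line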
    	
embodied_gen/validators/urdf_convertor.py
CHANGED

@@ -24,6 +24,7 @@ from xml.dom.minidom import parseString

 import numpy as np
 import trimesh
+from embodied_gen.data.convex_decomposer import decompose_convex_mesh
 from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
 from embodied_gen.utils.process_media import render_asset3d
 from embodied_gen.utils.tags import VERSION
@@ -84,6 +85,7 @@ class URDFGenerator(object):
         attrs_name: list[str] = None,
         render_dir: str = "urdf_renders",
         render_view_num: int = 4,
+        decompose_convex: bool = False,
     ) -> None:
         if mesh_file_list is None:
             mesh_file_list = []
@@ -107,36 +109,37 @@ class URDFGenerator(object):
         already provided, use it directly), accurately describe this 3D object asset (within 15 words),
         Determine the pose of the object in the first image and estimate the true vertical height
         (vertical projection) range of the object (in meters), i.e., how tall the object appears from top
-        to bottom in the
+        to bottom in the first image. also weight range (unit: kilogram), the average
         static friction coefficient of the object relative to rubber and the average dynamic friction
-        coefficient of the object relative to rubber. Return response format as shown in Output Example.
+        coefficient of the object relative to rubber. Return response in format as shown in Output Example.

         Output Example:
         Category: cup
         Description: shiny golden cup with floral design
-
+        Pose: <short_description_within_10_words>
+        Height: 0.10-0.15 m
         Weight: 0.3-0.6 kg
         Static friction coefficient: 0.6
         Dynamic friction coefficient: 0.5

-        IMPORTANT: Estimating Vertical Height from the First (Front View) Image.
+        IMPORTANT: Estimating Vertical Height from the First (Front View) Image and pose estimation based on all views.
         - The "vertical height" refers to the real-world vertical size of the object
         as projected in the first image, aligned with the image's vertical axis.
         - For flat objects like plates or disks or book, if their face is visible in the front view,
         use the diameter as the vertical height. If the edge is visible, use the thickness instead.
         - This is not necessarily the full length of the object, but how tall it appears
-        in the first image vertically, based on its pose and orientation.
-        - For objects(e.g., spoons, forks, writing instruments etc.) at an angle showing in
-
+        in the first image vertically, based on its pose and orientation estimation on all views.
+        - For objects(e.g., spoons, forks, writing instruments etc.) at an angle showing in images,
+            e.g., tilted at 45° will appear shorter vertically than when upright.
         Estimate the vertical projection of their real length based on its pose.
         For example:
-          - A pen standing upright in the first
+          - A pen standing upright in the first image (aligned with the image's vertical axis)
          full body visible in the first image: → vertical height ≈ 0.14-0.20 m
-          - A pen lying flat in the
+          - A pen lying flat in the first image (showing thickness or as a dot) → vertical height ≈ 0.018-0.025 m
          - Tilted pen in the first image (e.g., ~45° angle): vertical height ≈ 0.07-0.12 m
-        - Use the rest views
+        - Use the rest views to help determine the object's 3D pose and orientation.
        Assume the object is in real-world scale and estimate the approximate vertical height
-
+       based on the pose estimation and how large it appears vertically in the first image.
        """
        )
@@ -155,6 +158,7 @@ class URDFGenerator(object):
                 "gs_model",
             ]
         self.attrs_name = attrs_name
+        self.decompose_convex = decompose_convex

     def parse_response(self, response: str) -> dict[str, any]:
         lines = response.split("\n")
@@ -163,14 +167,14 @@ class URDFGenerator(object):
         description = lines[1].split(": ")[1]
         min_height, max_height = map(
             lambda x: float(x.strip().replace(",", "").split()[0]),
-            lines[
+            lines[3].split(": ")[1].split("-"),
         )
         min_mass, max_mass = map(
             lambda x: float(x.strip().replace(",", "").split()[0]),
-            lines[
+            lines[4].split(": ")[1].split("-"),
         )
-        mu1 = float(lines[
-        mu2 = float(lines[
+        mu1 = float(lines[5].split(": ")[1].replace(",", ""))
+        mu2 = float(lines[6].split(": ")[1].replace(",", ""))

         return {
             "category": category.lower(),
@@ -257,9 +261,24 @@ class URDFGenerator(object):
         # Update collision geometry
         collision = link.find("collision/geometry/mesh")
         if collision is not None:
-
-
-
+            collision_mesh = os.path.join(self.output_mesh_dir, obj_name)
+            if self.decompose_convex:
+                try:
+                    d_params = dict(
+                        threshold=0.05, max_convex_hull=64, verbose=False
+                    )
+                    filename = f"{os.path.splitext(obj_name)[0]}_collision.ply"
+                    output_path = os.path.join(mesh_folder, filename)
+                    decompose_convex_mesh(
+                        mesh_output_path, output_path, **d_params
+                    )
+                    collision_mesh = f"{self.output_mesh_dir}/{filename}"
+                except Exception as e:
+                    logger.warning(
+                        f"Convex decomposition failed for {output_path}, {e}."
+                        "Use original mesh for collision computation."
+                    )
+            collision.set("filename", collision_mesh)
             collision.set("scale", "1.0 1.0 1.0")

         # Update friction coefficients
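The index shift in `parse_response` follows from the two new lines ("Pose" and "Height") in the Output Example: heights are now read from `lines[3]`, weights from `lines[4]`, and friction coefficients from `lines[5]` and `lines[6]`. A small, runnable sketch that replays the string handling on a reply in the new format (the concrete values are illustrative):

    sample = "\n".join([
        "Category: cup",
        "Description: shiny golden cup with floral design",
        "Pose: upright, handle facing right",
        "Height: 0.10-0.15 m",
        "Weight: 0.3-0.6 kg",
        "Static friction coefficient: 0.6",
        "Dynamic friction coefficient: 0.5",
    ])
    lines = sample.split("\n")
    min_h, max_h = (
        float(x.strip().replace(",", "").split()[0])
        for x in lines[3].split(": ")[1].split("-")
    )
    mu1 = float(lines[5].split(": ")[1].replace(",", ""))
    mu2 = float(lines[6].split(": ")[1].replace(",", ""))
    print(min_h, max_h, mu1, mu2)  # 0.1 0.15 0.6 0.5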
    	
requirements.txt
CHANGED

@@ -50,4 +50,5 @@ tyro
 pyquaternion
 shapely
 sapien==3.0.0b1
+coacd
 typing_extensions==4.14.1
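The new `coacd` dependency backs the convex decomposition path added to `URDFGenerator`. A hedged usage sketch of the helper it relies on, mirroring the call made in the diff (the file paths are placeholders, and only the keyword arguments shown in `d_params` are assumed here):

    from embodied_gen.data.convex_decomposer import decompose_convex_mesh

    # Decompose a visual mesh into convex parts for collision geometry,
    # with the same parameters URDFGenerator uses when decompose_convex=True.
    decompose_convex_mesh(
        "mesh/sample.obj",            # placeholder input path
        "mesh/sample_collision.ply",  # placeholder output path
        threshold=0.05,
        max_convex_hull=64,
        verbose=False,
    )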