0cdce4a | """Interactive Genesis simulation wrapper for the demo. | |
| Loads pre-computed 3D reconstruction results and runs physics simulation | |
| with user-controlled forces, rendering frames and computing optical flow. | |
| """ | |

import base64
import io
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import trimesh
import genesis as gs
from omegaconf import OmegaConf
from simulation.utils import pt3d_to_gs, gs_to_pt3d, pose_to_transform_matrix
from simulation.case_simulation.case_handler import get_case_handler
from pytorch3d.renderer import PerspectiveCameras
from PIL import Image

# gs.init() may only run once per process; guarded in _setup_genesis().
_genesis_initialized = False


class InteractiveSimulator:
    """Wraps Genesis simulation for interactive force control."""

    def __init__(self, demo_data_path: str, device: str = "cuda",
                 config_overrides: dict | None = None):
        self.demo_data_path = Path(demo_data_path)
        self.device = torch.device(device)
        self.config = OmegaConf.to_container(
            OmegaConf.load(self.demo_data_path / "config.yaml"), resolve=True
        )
        self.config["device"] = device
        self.config["output_folder"] = str(self.demo_data_path / "sim_tmp")
        os.makedirs(self.config["output_folder"], exist_ok=True)
        self.config.setdefault("debug", False)
        if config_overrides:
            self.config.update(config_overrides)
        self.dt = self.config.get("dt", 0.01)
        self.substeps = self.config.get("substeps", 10)
        self.frame_steps = self.config.get("frame_steps", 5)
        self.material_type = self.config["material_type"]
        self.crop_start = self.config.get("crop_start", 176)
        self.object_masks_b64 = self._load_object_masks()
        self.demo_case_handler = None
        self._setup_scene()

    def _setup_scene(self):
        """Load pre-computed data and build Genesis scene."""
        meshes_dir = self.demo_data_path / "fg_meshes"
        pcs_dir = self.demo_data_path / "fg_pcs"
        mesh_files = sorted(meshes_dir.glob("mesh_*.obj"))
        pc_files = sorted(pcs_dir.glob("pc_*.pt"))
        self.fg_meshes = []
        for mf in mesh_files:
            mesh = trimesh.load(str(mf), process=False)
            self.fg_meshes.append({
                "vertices": torch.from_numpy(mesh.vertices).to(self.device).float(),
                "faces": torch.from_numpy(mesh.faces).to(self.device).long(),
                "colors": torch.from_numpy(
                    np.array(mesh.visual.vertex_colors)[:, :3] / 255.0
                ).to(self.device).float(),
            })
        # Keep two copies of each foreground point cloud: one in PyTorch3D
        # coordinates (for rendering) and one in Genesis coordinates (for
        # simulation).
        self.fg_pcs_pt3d = []
        self.fg_pcs_gs = []
        for pf in pc_files:
            data = torch.load(pf, map_location=self.device)
            self.fg_pcs_pt3d.append({
                "points": data["points"].to(self.device),
                "colors": data["colors"].to(self.device),
            })
            self.fg_pcs_gs.append({
                "points": pt3d_to_gs(data["points"].clone().to(self.device)),
                "colors": data["colors"].to(self.device),
            })
        # Mesh vertices are only used on the simulation side, so convert them
        # to Genesis coordinates in place.
        for mesh_info in self.fg_meshes:
            mesh_info["vertices"] = pt3d_to_gs(mesh_info["vertices"])
        cam_data = torch.load(self.demo_data_path / "camera.pt", map_location=self.device)
        bg_data = torch.load(self.demo_data_path / "bg_points.pt", map_location=self.device)
        gn_path = self.demo_data_path / "ground_plane_normal.npy"
        self.ground_plane_normal = None
        if gn_path.exists():
            self.ground_plane_normal = pt3d_to_gs(np.load(gn_path))
            # Orient the normal so its Z component is non-negative.
            if self.ground_plane_normal[2] < 0:
                self.ground_plane_normal = -self.ground_plane_normal
        self._setup_renderer(cam_data, bg_data)
        self._setup_genesis()

    def _setup_renderer(self, cam_data, bg_data):
        camera = PerspectiveCameras(
            K=cam_data["K"].to(self.device),
            R=cam_data["R"].to(self.device),
            T=cam_data["T"].to(self.device),
            in_ndc=False,
            image_size=((512, 512),),
            device=self.device,
        )
        self.svr = _MinimalSVR(
            config=self.config,
            camera=camera,
            focal_length=cam_data["focal_length"],
            bg_points=bg_data["points"].to(self.device),
            bg_points_colors=bg_data["colors"].to(self.device),
            fg_pcs=[{
                "points": pc["points"].clone(),
                "colors": pc["colors"].clone(),
            } for pc in self.fg_pcs_pt3d],
            device=self.device,
        )

    def _setup_genesis(self):
        all_obj_info = []
        all_lower = torch.tensor([float("inf")] * 3, device=self.device)
        all_upper = torch.tensor([float("-inf")] * 3, device=self.device)
        for idx, mesh_info in enumerate(self.fg_meshes):
            vmin = mesh_info["vertices"].min(0).values
            vmax = mesh_info["vertices"].max(0).values
            center = mesh_info["vertices"].mean(0)
            size = vmax - vmin
            # Center each mesh at the origin before export; Genesis places the
            # entity at its world position separately.
            mesh_info["vertices"] -= center
            mesh_path = os.path.join(self.config["output_folder"], f"fg_mesh_{idx:02d}.obj")
            t = trimesh.Trimesh(
                vertices=mesh_info["vertices"].cpu().numpy(),
                faces=mesh_info["faces"].cpu().numpy(),
                vertex_colors=mesh_info["colors"].cpu().numpy(),
            )
            t.export(mesh_path)
            all_obj_info.append({
                "min": vmin, "max": vmax, "center": center, "size": size,
                "mesh_path": mesh_path,
                "vertices": mesh_info["vertices"] + center,
            })
            all_lower = torch.minimum(all_lower, vmin)
            all_upper = torch.maximum(all_upper, vmax)
        self.all_obj_info = all_obj_info
        self.case_handler = get_case_handler(
            self.config["example_name"], self.config, all_obj_info, self.device
        )
        self.case_handler.set_simulation_bounds(all_lower, all_upper)
        sim_lower, sim_upper = self.case_handler.get_simulation_bounds()
        # Gravity: scalar configs are scaled along the (upward) ground normal;
        # vector configs are converted from PyTorch3D to Genesis coordinates.
        gravity_dir = (
            self.ground_plane_normal.copy()
            if self.ground_plane_normal is not None
            else np.array([0, 0, 1])
        )
        if "gravity" in self.config:
            if isinstance(self.config["gravity"], (int, float)):
                gravity = tuple(self.config["gravity"] * gravity_dir)
            else:
                gravity = tuple(pt3d_to_gs(np.array(self.config["gravity"])))
        else:
            gravity = tuple(-9.8 * gravity_dir)
        pbd_gravity = None
        if "pbd_gravity" in self.config:
            if isinstance(self.config["pbd_gravity"], (int, float)):
                pbd_gravity = tuple(self.config["pbd_gravity"] * gravity_dir)
            else:
                pbd_gravity = tuple(pt3d_to_gs(np.array(self.config["pbd_gravity"])))
        mpm_gravity = None
        if "mpm_gravity" in self.config:
            if isinstance(self.config["mpm_gravity"], (int, float)):
                mpm_gravity = tuple(self.config["mpm_gravity"] * gravity_dir)
            else:
                mpm_gravity = tuple(pt3d_to_gs(np.array(self.config["mpm_gravity"])))
        global _genesis_initialized
        if not _genesis_initialized:
            gs.init(seed=self.config.get("seed", 0), precision="32",
                    backend=gs.cpu, logging_level="warning")
            _genesis_initialized = True
        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(
                dt=self.dt, gravity=gravity, substeps=self.substeps,
            ),
            show_viewer=False,
            vis_options=gs.options.VisOptions(
                show_world_frame=False,
                show_link_frame=False,
                show_cameras=False,
                plane_reflection=False,
                ambient_light=(0.5, 0.5, 0.5),
                lights=[{
                    "type": "directional",
                    "dir": (0, 0, 1),
                    "color": (1.0, 1.0, 1.0),
                    "intensity": 2.0,
                }],
            ),
            renderer=gs.renderers.Rasterizer(),
            rigid_options=gs.options.RigidOptions(
                dt=self.dt,
                enable_collision=True,
                enable_self_collision=False,
                constraint_timeconst=0.02,
            ),
            pbd_options=gs.options.PBDOptions(
                lower_bound=tuple(sim_lower),
                upper_bound=tuple(sim_upper),
                particle_size=self.config.get("particle_size", 0.01),
                gravity=pbd_gravity,
            ),
            mpm_options=gs.options.MPMOptions(
                lower_bound=tuple(sim_lower),
                upper_bound=tuple(sim_upper),
                grid_density=self.config.get("MPM_grid_density", 64),
                particle_size=self.config.get("particle_size", 0.01),
                gravity=mpm_gravity,
            ),
            coupler_options=gs.options.LegacyCouplerOptions(
                rigid_pbd=True, rigid_mpm=True,
            ),
        )
        obj_materials = []
        obj_vis_modes = []
        for mt in self.material_type:
            mat, vis = self._get_material(mt)
            obj_materials.append(mat)
            obj_vis_modes.append(vis)
        self.objs = self.case_handler.add_entities_to_scene(
            self.scene, obj_materials, obj_vis_modes
        )
        self.case_handler.before_scene_building(
            self.scene, self.objs, self.ground_plane_normal
        )
        self.debug_cam = None
        self._debug_cam_failed = False
        if self.config.get("debug", False):
            self.debug_cam = self.scene.add_camera(
                res=(512, 512),
                pos=(0, -1, 0),
                lookat=(0, 1, 0),
                fov=self.config.get("fov_x_input", 60),
                GUI=False,
            )
            self._debug_output = Path(self.config["output_folder"])
            self._debug_gs_frames = self._debug_output / "gs_frames"
            self._debug_gs_frames.mkdir(parents=True, exist_ok=True)
        self.scene.build()
        self.case_handler.after_scene_building()
        # A few warm-up steps, then restore the initial state.
        for _ in range(3):
            self.scene.step()
        self.scene.reset()
        self.case_handler.fix_particles()
        # Cache what is needed to map simulation state back onto the dense
        # foreground point clouds at render time: the initial pose for rigid
        # bodies, and nearest-particle indices for particle-based objects.
        self.initial_transform_matrix = {}
        self.closest_indices = {}
        for obj_idx, mt in enumerate(self.material_type):
            if mt == "rigid":
                self.objs[obj_idx].solver.update_vgeoms_render_T()
                rigid_T = self.objs[obj_idx].solver._vgeoms_render_T
                rigid_idx = self.objs[obj_idx].idx
                self.initial_transform_matrix[obj_idx] = (
                    torch.tensor(rigid_T[rigid_idx, 0]).to(self.device).float()
                )
            elif mt in ("pbd_liquid", "pbd_cloth", "mpm_sand", "mpm_liquid",
                        "mpm_elastic", "mpm_snow", "mpm_elastic2plastic",
                        "pbd_elastic", "pbd_particle"):
                self.closest_indices[obj_idx] = self._map_pc_to_particles(obj_idx)
        self._init_particles_gpu = {
            obj_idx: torch.tensor(
                self.objs[obj_idx].init_particles,
                device=self.device, dtype=torch.float32,
            )
            for obj_idx in self.closest_indices
        }
        self.step_count = 0
        print("Genesis scene construction finished")

    def set_demo_case_handler(self, handler):
        self.demo_case_handler = handler

    def move_to_device(self, device):
        """Move all renderer/simulation tensors to target device (CPU↔GPU)."""
        dev = torch.device(device)
        self.device = dev
        # Move SVR (PyTorch3D renderer + camera + point clouds)
        self.svr.move_to_device(dev)
        # Move mesh data
        for mesh in self.fg_meshes:
            for k, v in list(mesh.items()):
                if isinstance(v, torch.Tensor):
                    mesh[k] = v.to(dev)
        # Move foreground point clouds
        for pc_list in (self.fg_pcs_pt3d, self.fg_pcs_gs):
            for pc in pc_list:
                for k, v in list(pc.items()):
                    if isinstance(v, torch.Tensor):
                        pc[k] = v.to(dev)
        # Move per-object transform matrices and initial particles
        for k in list(self.initial_transform_matrix.keys()):
            self.initial_transform_matrix[k] = self.initial_transform_matrix[k].to(dev)
        for k in list(self._init_particles_gpu.keys()):
            self._init_particles_gpu[k] = self._init_particles_gpu[k].to(dev)
        # Move obj_info tensors (shared with case_handler by reference)
        for obj_info in self.all_obj_info:
            for k, v in list(obj_info.items()):
                if isinstance(v, torch.Tensor):
                    obj_info[k] = v.to(dev)
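
    # A plausible CPU-first usage pattern (sketch; the surrounding demo app is
    # not shown in this file): build everything on CPU at startup, then move
    # tensors to the GPU right before generation.
    #
    #   sim = InteractiveSimulator("assets/demo_case", device="cpu")
    #   ...
    #   sim.move_to_device("cuda")  # at generation time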

    def _load_object_masks(self):
        masks_dir = self.demo_data_path / "fg_masks"
        if not masks_dir.exists():
            return []
        mask_files = sorted(masks_dir.glob("mask_*.png"))
        masks_b64 = []
        for mf in mask_files:
            with open(mf, "rb") as f:
                masks_b64.append(base64.b64encode(f.read()).decode("ascii"))
        return masks_b64
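
    def decode_object_masks(self):
        """Illustrative helper sketch (not part of the original pipeline):
        decode the base64-encoded PNG masks back to uint8 numpy arrays."""
        masks = []
        for b64 in self.object_masks_b64:
            png_bytes = base64.b64decode(b64)
            masks.append(np.array(Image.open(io.BytesIO(png_bytes))))
        return masks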

    def step(self, extract_points=True):
        """Run one simulation step with interactive force applied."""
        if self.demo_case_handler is not None:
            self.demo_case_handler.apply_forces(self, self.step_count)
        if self.debug_cam is not None and not self._debug_cam_failed:
            try:
                self.debug_cam.start_recording()
            except Exception:
                self._debug_cam_failed = True
        self.scene.step()
        if self.debug_cam is not None and not self._debug_cam_failed:
            try:
                render_out = self.debug_cam.render()
                cv2.imwrite(
                    str(self._debug_gs_frames / f"{self.step_count:04d}.png"),
                    render_out[0],
                )
            except Exception:
                self._debug_cam_failed = True
        self.step_count += 1
        if not extract_points:
            return None
        # Map the new simulation state back onto the dense foreground point
        # clouds (in PyTorch3D coordinates) for rendering.
        updated_all_obj_points = []
        for obj_idx, mt in enumerate(self.material_type):
            if mt == "rigid":
                # Rigid bodies: apply the pose change relative to the initial
                # pose to the original point cloud.
                pos = self.objs[obj_idx].get_pos().cpu().numpy()
                quat = self.objs[obj_idx].get_quat().cpu().numpy()
                T = torch.from_numpy(
                    pose_to_transform_matrix(pos, quat)
                ).to(self.device).float()
                T_inv = torch.linalg.inv(self.initial_transform_matrix[obj_idx])
                real_T = T @ T_inv
                pts_h = torch.cat([
                    self.fg_pcs_gs[obj_idx]["points"],
                    torch.ones(self.fg_pcs_gs[obj_idx]["points"].shape[0], 1, device=self.device),
                ], dim=1)
                updated = (real_T.unsqueeze(0) @ pts_h.unsqueeze(-1)).squeeze(-1)[:, :3]
                updated_all_obj_points.append(gs_to_pt3d(updated))
            else:
                # Particle-based objects: displace each point by the mean
                # displacement of its k nearest simulation particles.
                p_start = self.objs[obj_idx].particle_start
                p_end = self.objs[obj_idx].particle_end
                state = self.objs[obj_idx].solver.get_state(0)
                particles_now = state.pos[0, p_start:p_end].float()
                init_particles_gpu = self._init_particles_gpu.get(obj_idx)
                if init_particles_gpu is None:
                    init_particles_gpu = torch.tensor(
                        self.objs[obj_idx].init_particles,
                        device=self.device, dtype=torch.float32,
                    )
                delta = particles_now - init_particles_gpu
                pc_delta = delta[self.closest_indices[obj_idx]].mean(dim=1)
                updated = self.fg_pcs_gs[obj_idx]["points"] + pc_delta
                updated_all_obj_points.append(gs_to_pt3d(updated))
        return updated_all_obj_points

    def render_preview(self):
        frame_pil, _, _ = self.svr.render(frame_id=0, save=False, mask=False)
        return frame_pil

    def render_and_flow(self, updated_points, frame_id=None):
        """Render the current frame and compute optical flow."""
        self.svr.update_fg_obj_info(updated_points)
        if frame_id is None:
            frame_id = self.step_count
        save_debug = self.config.get("debug", False)
        frame_pil, fg_mask, mesh_mask = self.svr.render(
            frame_id=frame_id, save=save_debug, mask=True,
        )
        # Return flow as (2, H, W); zeros if no previous frame exists yet.
        if self.svr._last_optical_flow is not None:
            flow_hw3 = self.svr._last_optical_flow
            flow_2hw = flow_hw3[..., :2].transpose(2, 0, 1)
        else:
            flow_2hw = np.zeros((2, 512, 512), dtype=np.float32)
        return frame_pil, flow_2hw, fg_mask, mesh_mask

    def save_debug_outputs(self, sim_frames=None):
        if not self.config.get("debug", False):
            return
        from simulation.utils import save_gif_from_image_folder, save_video_from_pil
        output = self._debug_output
        render_dir = self.svr.output_folder
        if self.debug_cam is not None and not self._debug_cam_failed:
            try:
                self.debug_cam.stop_recording(
                    save_to_filename=str(output / "render_gs.mp4"), fps=10
                )
            except Exception as e:
                print(f"[debug] cam.stop_recording failed: {e}")
        if hasattr(self, '_debug_gs_frames') and self._debug_gs_frames.exists():
            save_gif_from_image_folder(
                str(self._debug_gs_frames), str(output / "simulated_frames_gs.gif")
            )
        svr_frames_dir = render_dir / "frames"
        if svr_frames_dir.exists():
            save_gif_from_image_folder(
                str(svr_frames_dir), str(output / "simulated_frames.gif")
            )
        svr_flow_dir = render_dir / "optical_flow"
        if svr_flow_dir.exists():
            save_gif_from_image_folder(
                str(svr_flow_dir), str(output / "flow_image.gif")
            )
        if sim_frames:
            save_video_from_pil(
                sim_frames, str(output / "simulated_frames.mp4"), fps=10
            )

    def reset(self):
        self.step_count = 0
        if self.demo_case_handler is not None:
            self.demo_case_handler.reset_forces()
        self.scene.reset()
        self.case_handler.fix_particles()
        # Clear renderer state so the next render starts a fresh flow sequence.
        self.svr.previous_frame_data = None
        self.svr.optical_flow = np.array([])
        self.svr._last_optical_flow = None
        self.svr.cache_bg = None
        self.svr._prev_fg_frags_idx = None
        self.svr._prev_fg_frags_dists = None

    def _map_pc_to_particles(self, obj_idx):
        """For each dense point, find the indices of its nearest simulation
        particles, computed in chunks to bound peak memory."""
        sim_particles = torch.tensor(
            self.objs[obj_idx].init_particles, device=self.device
        )
        K = 256  # chunk size over the dense point cloud
        num_closest = self.config.get("closest_points_num", 5)
        chunks = torch.split(self.fg_pcs_gs[obj_idx]["points"], K)
        indices = []
        for chunk in chunks:
            dists = torch.norm(
                chunk.unsqueeze(1) - sim_particles.unsqueeze(0), dim=2
            )
            indices.append(
                torch.topk(dists, k=num_closest, dim=1, largest=False)[1]
            )
            del dists
        return torch.cat(indices)
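
    # The chunked loop above is equivalent to the following one-liner, which
    # materializes the full N x M distance matrix (a reference sketch only):
    #   idx = torch.topk(torch.cdist(points, sim_particles),
    #                    k=num_closest, dim=1, largest=False)[1]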

    def _get_material(self, mt):
        """Return a (material, visualization mode) pair for a material-type string."""
        c = self.config
        if mt == "rigid":
            return gs.materials.Rigid(
                rho=c.get("rigid_rho", 1000.0),
                friction=c.get("rigid_friction", 5.0),
                coup_friction=c.get("rigid_coup_friction", 5),
                coup_softness=c.get("rigid_coup_softness", 0.002),
            ), "visual"
        elif mt == "pbd_cloth":
            return gs.materials.PBD.Cloth(
                rho=c.get("pbd_rho", 4.0),
                static_friction=c.get("pbd_static_friction", 0.6),
                kinetic_friction=c.get("pbd_kinetic_friction", 0.35),
                stretch_compliance=c.get("pbd_stretch_compliance", 1e-7),
                bending_compliance=c.get("pbd_bending_compliance", 1e-5),
                stretch_relaxation=c.get("pbd_stretch_relaxation", 0.7),
                bending_relaxation=c.get("pbd_bending_relaxation", 0.1),
                air_resistance=c.get("pbd_air_resistance", 5e-3),
            ), "particle"
        elif mt == "pbd_elastic":
            return gs.materials.PBD.Elastic(
                rho=c.get("pbd_elastic_rho", 300.0),
                static_friction=c.get("pbd_elastic_static_friction", 0.15),
                kinetic_friction=c.get("pbd_elastic_kinetic_friction", 0.0),
                stretch_compliance=c.get("pbd_elastic_stretch_compliance", 0.0),
                bending_compliance=c.get("pbd_elastic_bending_compliance", 0.0),
                volume_compliance=c.get("pbd_elastic_volume_compliance", 0.0),
                stretch_relaxation=c.get("pbd_elastic_stretch_relaxation", 0.1),
                bending_relaxation=c.get("pbd_elastic_bending_relaxation", 0.1),
                volume_relaxation=c.get("pbd_elastic_volume_relaxation", 0.1),
            ), "particle"
        elif mt == "mpm_sand":
            return gs.materials.MPM.Sand(
                E=c.get("MPM_E", 1e6), nu=c.get("MPM_nu", 0.2),
                rho=c.get("MPM_rho", 1000.0),
                friction_angle=c.get("MPM_friction_angle", 45),
            ), "particle"
        elif mt == "mpm_elastic":
            return gs.materials.MPM.Elastic(
                E=c.get("MPM_E", 1e6), nu=c.get("MPM_nu", 0.2),
                rho=c.get("MPM_rho", 1000.0),
            ), "particle"
        elif mt == "mpm_liquid":
            return gs.materials.MPM.Liquid(
                E=c.get("MPM_E", 1e6), nu=c.get("MPM_nu", 0.2),
                rho=c.get("MPM_rho", 1000.0),
            ), "particle"
        elif mt == "mpm_snow":
            return gs.materials.MPM.Snow(
                E=c.get("MPM_E", 1e6), nu=c.get("MPM_nu", 0.2),
                rho=c.get("MPM_rho", 1000.0),
            ), "particle"
        elif mt == "pbd_liquid":
            return gs.materials.PBD.Liquid(
                rho=c.get("pbd_rho", 1000.0),
                density_relaxation=c.get("pbd_density_relaxation", 0.2),
                viscosity_relaxation=c.get("pbd_viscosity_relaxation", 0.1),
            ), "particle"
        elif mt == "pbd_particle":
            return gs.materials.PBD.Particle(), "particle"
        else:
            raise NotImplementedError(f"Material {mt} not supported")
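

# Illustrative config overrides (hypothetical values; the keys are ones read
# by InteractiveSimulator.__init__, _get_material, and _MinimalSVR.render):
_EXAMPLE_CONFIG_OVERRIDES = {
    "dt": 0.01,
    "substeps": 10,
    "particle_size": 0.01,
    "closest_points_num": 5,
    "alpha_threshold": 0.5,  # required by _MinimalSVR.render()
    "debug": False,
}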


class _MinimalSVR:
    """Minimal point-cloud renderer with optical flow computation.

    Provides render() and update_fg_obj_info() with pre-loaded data.
    _proj_uv and save_optical_flow are inlined here to avoid importing
    SingleViewReconstructor (which pulls in SAM3D, MoGe, FluxInpainter).
    """

    def __init__(self, config, camera, focal_length, bg_points,
                 bg_points_colors, fg_pcs, device):
        self.config = config
        self.current_camera = camera
        self.init_focal_length = focal_length
        self.bg_points = bg_points
        self.bg_points_colors = bg_points_colors
        self.fg_pcs = fg_pcs
        self.device = device
        self.target_size = (512, 512)
        self.previous_frame_data = None
        self.optical_flow = np.array([])
        self._last_optical_flow = None
        self._prev_fg_frags_idx = None
        self._prev_fg_frags_dists = None
        self.franka_mesh = None
        self.merge_mask = config.get("merge_mask", False)
        self.cache_bg = None
        self.fg_objects = []
        self.output_folder = Path(config.get("output_folder", "/tmp/svr_render"))
        self.output_folder_frames = self.output_folder / "frames"
        self.output_folder_masks = self.output_folder / "masks"
        self.output_folder_optical_flow = self.output_folder / "optical_flow"
        if config.get("debug", False):
            for d in [self.output_folder_frames, self.output_folder_masks,
                      self.output_folder_optical_flow]:
                d.mkdir(parents=True, exist_ok=True)
        self._build_cached_renderers()

    def _build_cached_renderers(self):
        from pytorch3d.renderer import (
            PointsRenderer, PointsRasterizer, PointsRasterizationSettings,
            AlphaCompositor,
        )
        cameras = self.current_camera
        image_size = self.target_size[0]
        fg_raster_settings = PointsRasterizationSettings(
            image_size=image_size,
            radius=self.config.get('fg_points_render_radius', 0.01),
            points_per_pixel=30,
            max_points_per_bin=20000,
            bin_size=0,
        )
        self._fg_rasterizer = PointsRasterizer(
            cameras=cameras, raster_settings=fg_raster_settings,
        )
        self._fg_renderer = PointsRenderer(
            rasterizer=self._fg_rasterizer,
            compositor=AlphaCompositor(),
        )
        # The flow rasterizer uses the same settings as the foreground one so
        # cached foreground fragments can be reused for flow splatting.
        flow_raster_settings = PointsRasterizationSettings(
            image_size=image_size,
            radius=self.config.get('fg_points_render_radius', 0.01),
            points_per_pixel=30,
            max_points_per_bin=20000,
            bin_size=0,
        )
        self._flow_rasterizer = PointsRasterizer(
            cameras=cameras, raster_settings=flow_raster_settings,
        )
        self._flow_renderer = PointsRenderer(
            rasterizer=self._flow_rasterizer,
            compositor=AlphaCompositor(),
        )

    def move_to_device(self, device):
        """Move all tensors to target device and rebuild renderers."""
        from pytorch3d.renderer import PerspectiveCameras
        cam = self.current_camera
        self.current_camera = PerspectiveCameras(
            K=cam.K.to(device),
            R=cam.R.to(device),
            T=cam.T.to(device),
            in_ndc=False,
            image_size=((512, 512),),
            device=device,
        )
        self.bg_points = self.bg_points.to(device)
        self.bg_points_colors = self.bg_points_colors.to(device)
        for pc in self.fg_pcs:
            pc['points'] = pc['points'].to(device)
            pc['colors'] = pc['colors'].to(device)
        self.device = device
        self.cache_bg = None  # stale after device change; recomputed on next render
        self._build_cached_renderers()

    def update_fg_obj_info(self, all_obj_points):
        for idx, pts in enumerate(all_obj_points):
            self.fg_pcs[idx]["points"] = pts.clone()

    def _proj_uv(self, xyz, camera, image_size):
        """Project 3D points to 2D UV coordinates."""
        device = xyz.device
        K_4x4 = camera.K[0]
        intr = K_4x4[:3, :3].clone()
        # World-to-camera transform assembled from the camera's R and T.
        w2c = torch.eye(4).float().to(device)
        R_w2c = camera.R[0]
        T_w2c = camera.T[0]
        w2c[:3, :3] = R_w2c
        w2c[:3, 3] = T_w2c
        intr[2, 2] = 1.0
        intr = intr.to(device)
        # Pinhole projection: camera-space points, then perspective divide.
        c_xyz = (w2c[:3, :3] @ xyz.T).T + w2c[:3, 3]
        i_xyz = (intr @ c_xyz.T).T
        uv = i_xyz[:, :2] / i_xyz[:, -1:].clip(1e-3)
        # Flip both axes to match the orientation of the point renders.
        uv = image_size - uv
        return uv

    def save_optical_flow(self, optical_flow, valid_mask, frame_id):
        """Save optical flow visualization to disk (debug mode only)."""
        flow_x = optical_flow[:, :, 0].cpu().numpy()
        flow_y = optical_flow[:, :, 1].cpu().numpy()
        valid_mask_np = valid_mask.cpu().numpy()
        # Encode flow direction as hue (OpenCV's uint8 hue range is 0-179).
        angle = np.arctan2(-flow_y, flow_x)
        hsv = np.zeros((optical_flow.shape[0], optical_flow.shape[1], 3), dtype=np.uint8)
        hsv[..., 0] = (angle + np.pi) / (2 * np.pi) * 179
        hsv[..., 1] = 255
        hsv[..., 2] = 255
        hsv[~valid_mask_np] = 0
        flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        import matplotlib.pyplot as plt
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(flow_rgb)
        ax1.set_title(f'Optical Flow Direction - Frame {frame_id}')
        ax1.axis('off')
        ax2.axis('off')
        plt.tight_layout()
        plt.savefig(
            f'{self.output_folder_optical_flow}/optical_flow_frame_{frame_id:04d}.png',
            dpi=150, bbox_inches='tight',
        )
        plt.close()

    def render(self, render_bg=True, render_obj=True, render_mesh=True,
               frame_id=0, save=False, mask=True, compute_optical_flow=True):
        from pytorch3d.structures import Pointclouds
        from torchvision.transforms import ToPILImage
        cameras = self.current_camera
        image_size = self.target_size[0]
        # Background (static, so it is rendered once and cached)
        if render_bg and self.cache_bg is None:
            from pytorch3d.renderer import (
                PointsRenderer, PointsRasterizer, PointsRasterizationSettings,
                AlphaCompositor,
            )
            bg_pc = Pointclouds(
                points=[self.bg_points], features=[self.bg_points_colors],
            )
            bg_raster_settings = PointsRasterizationSettings(
                image_size=image_size,
                radius=self.config.get('bg_points_render_radius', 0.0001),
                points_per_pixel=30,
            )
            bg_renderer = PointsRenderer(
                rasterizer=PointsRasterizer(
                    cameras=cameras, raster_settings=bg_raster_settings,
                ),
                compositor=AlphaCompositor(),
            )
            self.cache_bg = bg_renderer(bg_pc)
        if render_bg and self.cache_bg is not None:
            bg_image = self.cache_bg
        else:
            bg_image = torch.zeros(1, image_size, image_size, 3, device=self.device)
        base_rgb = bg_image[0].clone()
        final_rgb = base_rgb.clone()
        # Foreground: rasterize all object point clouds together
        all_fg_points = []
        all_fg_colors = []
        for pc_info in self.fg_pcs:
            all_fg_points.append(pc_info['points'])
            all_fg_colors.append(pc_info['colors'])
        combined_fg_points = torch.cat(all_fg_points, dim=0)
        combined_fg_colors = torch.cat(all_fg_colors, dim=0)
        flow_rendered_points = combined_fg_points.clone()
        combined_rgba = torch.cat([
            combined_fg_colors,
            torch.ones_like(combined_fg_colors[..., :1]),
        ], dim=-1)
        fg_pc = Pointclouds(points=[combined_fg_points], features=[combined_rgba])
        fragments = self._fg_rasterizer(fg_pc)
        # Composite manually with distance-based weights (1 - d^2 / r^2); the
        # same fragments are cached below for flow splatting.
        r = self._fg_rasterizer.raster_settings.radius
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        weights = 1 - dists2 / (r * r)
        fg_image = self._fg_renderer.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            fg_pc.features_packed().permute(1, 0),
        )
        fg_image = fg_image.permute(0, 2, 3, 1)
        fg_rgb = fg_image[0, ..., :3]
        fg_alpha = fg_image[0, ..., 3:4]
        fg_depth = fragments.zbuf[0, ..., 0]
        fg_points_mask = torch.where(
            fg_alpha.squeeze(-1) > self.config['alpha_threshold'], 1.0, 0.0,
        ).unsqueeze(-1)
        fg_mask_2d = fg_points_mask.squeeze(-1)
        final_rgb = fg_rgb * fg_mask_2d.unsqueeze(-1) + final_rgb * (1.0 - fg_mask_2d.unsqueeze(-1))
        # Optional mesh overlay (franka_mesh), depth-tested against the points
        mesh_mask = torch.zeros(image_size, image_size, 1, dtype=torch.float32, device=self.device)
        if render_mesh and self.franka_mesh is not None:
            from pytorch3d.renderer import (
                MeshRenderer, MeshRasterizer, SoftPhongShader,
                RasterizationSettings, BlendParams,
            )
            from pytorch3d.structures import Meshes
            from pytorch3d.renderer.mesh.textures import TexturesVertex
            vertices = self.franka_mesh['vertices']
            faces = self.franka_mesh['faces']
            colors = self.franka_mesh['colors']
            flow_rendered_points = torch.cat([flow_rendered_points, vertices], dim=0)
            if not isinstance(vertices, torch.Tensor):
                vertices = torch.tensor(vertices, dtype=torch.float32, device=self.device)
            if not isinstance(faces, torch.Tensor):
                faces = torch.tensor(faces, dtype=torch.long, device=self.device)
            if not isinstance(colors, torch.Tensor):
                colors = torch.tensor(colors, dtype=torch.float32, device=self.device)
            vertices = vertices.to(self.device)
            faces = faces.to(self.device)
            colors = colors.to(self.device)
            textures = TexturesVertex(verts_features=[colors])
            combined_mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)
            mesh_raster_settings = RasterizationSettings(
                image_size=image_size, blur_radius=0.0, faces_per_pixel=10,
            )
            mesh_rasterizer = MeshRasterizer(cameras=cameras, raster_settings=mesh_raster_settings)
            mesh_renderer = MeshRenderer(
                rasterizer=mesh_rasterizer,
                shader=SoftPhongShader(
                    device=self.device, cameras=cameras,
                    blend_params=BlendParams(background_color=(0.0, 0.0, 0.0)),
                ),
            )
            mesh_image = mesh_renderer(combined_mesh)
            mesh_rgb = mesh_image[0, ..., :3]
            mesh_alpha = mesh_image[0, ..., 3:4]
            mesh_fragments = mesh_rasterizer(combined_mesh)
            mesh_depth = mesh_fragments.zbuf[0, ..., 0]
            mesh_mask_2d = torch.where(mesh_alpha.squeeze(-1) > 0.01, 1.0, 0.0)
            # Only draw mesh pixels that are closer than the foreground points.
            fg_depth_valid = torch.where(fg_mask_2d > 0, fg_depth, torch.tensor(float('inf'), device=self.device))
            mesh_depth_valid = torch.where(mesh_mask_2d > 0, mesh_depth, torch.tensor(float('inf'), device=self.device))
            mesh_closer_bool = (mesh_depth_valid < fg_depth_valid) & (mesh_mask_2d > 0)
            mesh_closer_float = mesh_closer_bool.float()
            mesh_mask = mesh_closer_float.unsqueeze(-1)
            mesh_closer_3d = mesh_closer_float.unsqueeze(-1)
            final_rgb = mesh_rgb * mesh_closer_3d + final_rgb * (1.0 - mesh_closer_3d)
            fg_points_mask = torch.where(
                mesh_closer_bool.unsqueeze(-1),
                torch.zeros_like(fg_points_mask), fg_points_mask,
            )
        # Optical flow against the previous frame, if one exists
        if compute_optical_flow and self.previous_frame_data is not None:
            optical_flow = self._compute_optical_flow_pytorch3d_style(
                current_fg_points=flow_rendered_points,
                prev_fg_points=self.previous_frame_data['flow_rendered_points'],
                current_camera=cameras,
                prev_camera=self.previous_frame_data['camera'],
                image_size=image_size,
                frame_id=frame_id,
                prev_frags_idx=self._prev_fg_frags_idx,
                prev_frags_dists=self._prev_fg_frags_dists,
            )
            flow_np = optical_flow.cpu().numpy()
            self._last_optical_flow = flow_np
            if self.config.get('debug', False):
                if self.optical_flow.size == 0:
                    self.optical_flow = np.expand_dims(flow_np, 0)
                else:
                    self.optical_flow = np.concatenate([
                        self.optical_flow, np.expand_dims(flow_np, 0),
                    ])
        # Cache this frame's fragments for the next frame's flow splat; only
        # valid when no mesh vertices were appended to the point set.
        if self.franka_mesh is None:
            self._prev_fg_frags_idx = fragments.idx
            self._prev_fg_frags_dists = fragments.dists
        else:
            self._prev_fg_frags_idx = None
            self._prev_fg_frags_dists = None
        if save:
            if mask:
                points_mask_path = self.output_folder_masks / f"points_mask_{frame_id:04d}.png"
                points_mask_to_save = fg_points_mask.squeeze(2) if fg_points_mask.dim() == 3 else fg_points_mask
                ToPILImage()(points_mask_to_save.unsqueeze(0).clamp(0, 1).cpu()).save(points_mask_path.as_posix())
                mesh_mask_path = self.output_folder_masks / f"mesh_mask_{frame_id:04d}.png"
                mesh_mask_to_save = mesh_mask.squeeze(2) if mesh_mask.dim() == 3 else mesh_mask
                ToPILImage()(mesh_mask_to_save.unsqueeze(0).clamp(0, 1).cpu()).save(mesh_mask_path.as_posix())
            image_pil = ToPILImage()(final_rgb.permute(2, 0, 1).clamp(0, 1).cpu())
            image_path = self.output_folder_frames / f"frame_{frame_id:04d}.png"
            image_pil.save(image_path.as_posix())
        else:
            image_pil = ToPILImage()(final_rgb.permute(2, 0, 1).clamp(0, 1).cpu())
        self.previous_frame_data = {
            'camera': cameras,
            'bg_points': self.bg_points,
            'flow_rendered_points': flow_rendered_points,
        }
        return image_pil, fg_points_mask, mesh_mask

    def _compute_optical_flow_pytorch3d_style(self, current_fg_points, prev_fg_points,
                                              current_camera, prev_camera,
                                              image_size=512, frame_id=0,
                                              prev_frags_idx=None,
                                              prev_frags_dists=None):
        from pytorch3d.structures import Pointclouds
        # Align point counts between frames (they differ when mesh vertices
        # were appended in one frame but not the other); extra previous points
        # are carried over so they contribute zero flow.
        if current_fg_points.shape[0] > prev_fg_points.shape[0]:
            current_fg_points = current_fg_points[:prev_fg_points.shape[0]]
        elif prev_fg_points.shape[0] > current_fg_points.shape[0]:
            prev_more = prev_fg_points[-(prev_fg_points.shape[0] - current_fg_points.shape[0]):]
            current_fg_points = torch.cat([current_fg_points, prev_more], dim=0)
        # Per-point flow = current projection minus previous projection.
        current_uv = self._proj_uv(current_fg_points, current_camera, image_size)
        prev_uv = self._proj_uv(prev_fg_points, prev_camera, image_size)
        delta_uv = current_uv - prev_uv
        flow_colors = torch.cat([delta_uv, torch.zeros_like(delta_uv[:, :1])], dim=-1)
        # Encode flow values into [0.1, 0.9] so real values stay
        # distinguishable from empty pixels during compositing; exact zeros
        # are encoded as 0.0 and decoded back to zero below.
        xy_flow = flow_colors[:, :2]
        magnitude = torch.sqrt(xy_flow[:, 0] ** 2 + xy_flow[:, 1] ** 2)
        zero_flow_mask = magnitude < 1e-4
        min_val = xy_flow.min()
        max_val = xy_flow.max()
        if max_val - min_val > 1e-4:
            flow_colors[:, :2] = 0.1 + (xy_flow - min_val) / (max_val - min_val) * 0.8
            flow_colors[zero_flow_mask, :2] = 0.0
        else:
            flow_colors[:, :2] = 0.5
        flow_colors = torch.clamp(flow_colors, 0, 1)
        flow_rgba = torch.cat([flow_colors, torch.ones_like(flow_colors[..., :1])], dim=-1)
        # Splat the encoded flow at the previous frame's pixel locations,
        # reusing cached foreground fragments when available.
        if prev_frags_idx is not None and prev_frags_dists is not None:
            r = self._fg_rasterizer.raster_settings.radius
            dists2 = prev_frags_dists.permute(0, 3, 1, 2)
            prev_weights = 1 - dists2 / (r * r)
            flow_image_raw = self._fg_renderer.compositor(
                prev_frags_idx.long().permute(0, 3, 1, 2),
                prev_weights,
                flow_rgba.permute(1, 0),
            )
            flow_image = flow_image_raw.permute(0, 2, 3, 1)
        else:
            point_cloud = Pointclouds(points=[prev_fg_points], features=[flow_rgba])
            flow_image = self._flow_renderer(point_cloud)
        flow_alpha = flow_image[0, :, :, 3]
        valid_mask = flow_alpha > self.config['alpha_threshold']
        # Decode the [0.1, 0.9] encoding back to pixel-space flow values.
        optical_flow = torch.zeros(image_size, image_size, 3, device=self.device)
        if valid_mask.sum() > 0 and max_val - min_val > 1e-4:
            rendered_flow = flow_image[0, :, :, :2][valid_mask]
            zero_pixels = torch.all(rendered_flow < 0.05, dim=-1)
            normal_pixels = ~zero_pixels
            full_flow = torch.zeros_like(rendered_flow)
            if normal_pixels.sum() > 0:
                full_flow[normal_pixels] = (
                    (rendered_flow[normal_pixels] - 0.1) / 0.8 * (max_val - min_val) + min_val
                )
            optical_flow[:, :, :2][valid_mask] = full_flow
            if self.config.get('debug', False):
                meaningful_mask = valid_mask.clone()
                valid_coords = torch.where(valid_mask)
                meaningful_mask[valid_coords[0][zero_pixels], valid_coords[1][zero_pixels]] = False
                self.save_optical_flow(optical_flow, meaningful_mask, frame_id)
        return optical_flow

    def num_fg_objects(self):
        return len(self.fg_pcs)
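

if __name__ == "__main__":
    # Smoke-test sketch: the demo-data path below is an assumption; point it
    # at a folder with the layout described at the top of this file.
    sim = InteractiveSimulator(
        "assets/demo_case", device="cpu",
        config_overrides=_EXAMPLE_CONFIG_OVERRIDES,
    )
    sim.render_preview().save("preview.png")
    for _ in range(10):
        updated_points = sim.step(extract_points=True)
        frame, flow_2hw, fg_mask, mesh_mask = sim.render_and_flow(updated_points)
        print(f"step {sim.step_count}: frame {frame.size}, flow {flow_2hw.shape}")
    sim.reset()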