# (Hugging Face Spaces badge residue: "Spaces: Running on Zero" — ZeroGPU Space)
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import os | |
| # not ideal to put that here | |
| os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"] | |
| os.environ["LIDRA_SKIP_INIT"] = "true" | |
| import sys | |
| from typing import Union, Optional, List, Callable | |
| import numpy as np | |
| from PIL import Image | |
| from omegaconf import OmegaConf, DictConfig, ListConfig | |
| from hydra.utils import instantiate, get_method | |
| import torch | |
| import math | |
| import utils3d | |
| import shutil | |
| import subprocess | |
| import seaborn as sns | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from copy import deepcopy | |
| from kaolin.visualize import IpyTurntableVisualizer | |
| from kaolin.render.camera import Camera, CameraExtrinsics, PinholeIntrinsics | |
| import builtins | |
| from pytorch3d.transforms import quaternion_multiply, quaternion_invert | |
| import sam3d_objects # REMARK(Pierre) : do not remove this import | |
| from sam3d_objects.pipeline.inference_pipeline_pointmap import InferencePipelinePointMap | |
| from sam3d_objects.model.backbone.tdfy_dit.utils import render_utils | |
| from sam3d_objects.utils.visualization import SceneVisualizer | |
| __all__ = ["Inference"] | |
| WHITELIST_FILTERS = [ | |
| lambda target: target.split(".", 1)[0] in {"sam3d_objects", "torch", "torchvision", "moge"}, | |
| ] | |
| BLACKLIST_FILTERS = [ | |
| lambda target: get_method(target) | |
| in { | |
| builtins.exec, | |
| builtins.eval, | |
| builtins.__import__, | |
| os.kill, | |
| os.system, | |
| os.putenv, | |
| os.remove, | |
| os.removedirs, | |
| os.rmdir, | |
| os.fchdir, | |
| os.setuid, | |
| os.fork, | |
| os.forkpty, | |
| os.killpg, | |
| os.rename, | |
| os.renames, | |
| os.truncate, | |
| os.replace, | |
| os.unlink, | |
| os.fchmod, | |
| os.fchown, | |
| os.chmod, | |
| os.chown, | |
| os.chroot, | |
| os.fchdir, | |
| os.lchown, | |
| os.getcwd, | |
| os.chdir, | |
| shutil.rmtree, | |
| shutil.move, | |
| shutil.chown, | |
| subprocess.Popen, | |
| builtins.help, | |
| }, | |
| ] | |
class Inference:
    """Public-facing inference API around ``InferencePipelinePointMap``.

    Only publicly exposed arguments belong here.
    """

    def __init__(self, config_file: str, compile: bool = False):
        """Load, validate and instantiate the inference pipeline.

        Args:
            config_file: path to an OmegaConf YAML config; its directory
                becomes the pipeline workspace.
            compile: forwarded to the pipeline config as ``compile_model``.
        """
        config = OmegaConf.load(config_file)
        config.rendering_engine = "pytorch3d"  # overwrite to disable nvdiffrast
        config.compile_model = compile
        config.workspace_dir = os.path.dirname(config_file)
        # refuse to instantiate dangerous targets (exec/eval, fs mutation, ...)
        check_hydra_safety(config, WHITELIST_FILTERS, BLACKLIST_FILTERS)
        self._pipeline: "InferencePipelinePointMap" = instantiate(config)

    def merge_mask_to_rgba(self, image, mask):
        """Embed a binary mask as the alpha channel of an RGB(A) image.

        Args:
            image: HxWx3(+)-channel image (ndarray or PIL Image); only the
                first three channels are kept.
            mask: HxW boolean mask (ndarray or PIL Image), or None for a
                fully opaque alpha channel.

        Returns:
            HxWx4 uint8-compatible ndarray with the mask in the alpha channel.
        """
        image = np.asarray(image)
        if mask is None:
            # no mask provided: treat every pixel as foreground
            # (the original crashed on None despite __call__'s Optional hint)
            alpha = np.full(image.shape[:2] + (1,), 255, dtype=np.uint8)
        else:
            alpha = np.asarray(mask).astype(np.uint8) * 255
            alpha = alpha[..., None]
        # embed mask in alpha channel
        return np.concatenate([image[..., :3], alpha], axis=-1)

    def __call__(
        self,
        image: "Union[Image.Image, np.ndarray]",
        mask: "Optional[Union[None, Image.Image, np.ndarray]]",
        seed: "Optional[int]" = None,
        pointmap=None,
    ) -> dict:
        """Run single-image 3D inference; returns the pipeline's result dict."""
        image = self.merge_mask_to_rgba(image, mask)
        return self._pipeline.run(
            image,
            None,
            seed,
            stage1_only=False,
            with_mesh_postprocess=False,
            with_texture_baking=False,
            with_layout_postprocess=True,
            use_vertex_color=True,
            stage1_inference_steps=None,
            pointmap=pointmap,
        )
def _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    """Convert yaw/pitch/radius/fov camera parameters to (extrinsics, intrinsics).

    Accepts scalars or lists for every argument; when ``yaws`` is a scalar a
    single pair is returned instead of lists. All tensors are created on GPU.
    """
    scalar_input = not isinstance(yaws, list)
    if scalar_input:
        yaws = [yaws]
        pitchs = [pitchs]
    if not isinstance(rs, list):
        rs = [rs] * len(yaws)
    if not isinstance(fovs, list):
        fovs = [fovs] * len(yaws)

    extrinsics = []
    intrinsics = []
    for yaw_v, pitch_v, radius, fov_deg in zip(yaws, pitchs, rs, fovs):
        fov_rad = torch.deg2rad(torch.tensor(float(fov_deg))).cuda()
        yaw_t = torch.tensor(float(yaw_v)).cuda()
        pitch_t = torch.tensor(float(pitch_v)).cuda()
        # camera origin on a sphere of the given radius around the world origin
        eye = (
            torch.tensor(
                [
                    torch.sin(yaw_t) * torch.cos(pitch_t),
                    torch.sin(pitch_t),
                    torch.cos(yaw_t) * torch.cos(pitch_t),
                ]
            ).cuda()
            * radius
        )
        extrinsics.append(
            utils3d.torch.extrinsics_look_at(
                eye,
                torch.tensor([0, 0, 0]).float().cuda(),
                torch.tensor([0, 1, 0]).float().cuda(),
            )
        )
        intrinsics.append(utils3d.torch.intrinsics_from_fov_xy(fov_rad, fov_rad))

    if scalar_input:
        return extrinsics[0], intrinsics[0]
    return extrinsics, intrinsics
def render_video(
    sample,
    resolution=512,
    bg_color=(0, 0, 0),
    num_frames=300,
    r=2.0,
    fov=40,
    pitch_deg=0,
    yaw_start_deg=-90,
    **kwargs,
):
    """Render a full-turn turntable video of ``sample`` with the gsplat backend.

    The camera sweeps 360 degrees of yaw (starting at ``yaw_start_deg``) at a
    fixed pitch, radius and field of view; extra kwargs are forwarded to
    ``render_utils.render_frames``.
    """
    # keep the tensor arithmetic so the yaw values match the original exactly
    yaws = (
        torch.linspace(0, 2 * torch.pi, num_frames) + math.radians(yaw_start_deg)
    ).tolist()
    pitches = [math.radians(pitch_deg)] * num_frames
    extrinsics, intrinsics = _yaw_pitch_r_fov_to_extrinsics_intrinsics(
        yaws, pitches, r, fov
    )
    render_options = {"resolution": resolution, "bg_color": bg_color, "backend": "gsplat"}
    return render_utils.render_frames(
        sample,
        extrinsics,
        intrinsics,
        render_options,
        **kwargs,
    )
def ready_gaussian_for_video_rendering(scene_gs, in_place=False, fix_alignment=False):
    """Prepare a gaussian scene for video rendering (align, then normalize).

    Args:
        scene_gs: gaussian splat scene object.
        in_place: if False, work on a deep copy and leave the input untouched.
        fix_alignment: also run the alignment fix before normalization.
    """
    if fix_alignment:
        # _fix_gaussian_alignment already honors `in_place` (copies if asked),
        # so normalization below may safely run in place on its result.
        scene_gs = _fix_gaussian_alignment(scene_gs, in_place=in_place)
    # BUGFIX: the original passed in_place=fix_alignment, so a caller asking
    # for in_place=True with fix_alignment=False still got a deep copy.
    scene_gs = normalized_gaussian(scene_gs, in_place=in_place or fix_alignment)
    return scene_gs
def _fix_gaussian_alignment(scene_gs, in_place=False):
    """Fix the alignment of a gaussian scene's positions (currently a no-op).

    NOTE(review): the ``_xyz`` assignment below rewrites the attribute with
    itself, and ``device``/``dtype`` are computed but never used — this looks
    like a transform that was removed or stripped; confirm against history.

    Args:
        scene_gs: gaussian splat scene object exposing ``_xyz``.
        in_place: if False, operate on (and return) a deep copy.
    """
    if not in_place:
        scene_gs = deepcopy(scene_gs)
    device = scene_gs._xyz.device
    dtype = scene_gs._xyz.dtype
    # no-op: assigns _xyz back to itself unchanged
    scene_gs._xyz = (
        scene_gs._xyz
    )
    return scene_gs
def normalized_gaussian(scene_gs, in_place=False, outlier_percentile=None):
    """Normalize a gaussian scene to a unit-ish, centered frame.

    The scene is divided by the largest axis-aligned extent of its "active"
    gaussians (opacity > 0.9) and re-centered on the midpoint of their
    bounding box (optionally a quantile-based robust bounding box).

    Args:
        scene_gs: gaussian splat scene object.
        in_place: if False, operate on a deep copy.
        outlier_percentile: when given, use the [p, 1-p] per-axis quantiles
            of the active points instead of their min/max for centering.
    """
    if not in_place:
        scene_gs = deepcopy(scene_gs)

    xyz = scene_gs.get_xyz
    scaling = scene_gs.get_scaling
    active = (scene_gs.get_opacity > 0.9).squeeze()

    # normalize by the largest per-axis extent of the active points
    extent = (xyz[active].max(dim=0)[0] - xyz[active].min(dim=0)[0]).max()
    xyz = xyz / extent
    scaling = scaling / extent

    if outlier_percentile is None:
        lower = torch.min(xyz[active], dim=0)[0]
        upper = torch.max(xyz[active], dim=0)[0]
    else:
        lower = torch.quantile(
            xyz[active],
            outlier_percentile,
            dim=0,
        )
        upper = torch.quantile(
            xyz[active],
            1.0 - outlier_percentile,
            dim=0,
        )

    # re-center on the midpoint of the (possibly robust) bounding box
    xyz = xyz - (lower + upper) / 2

    scene_gs.from_xyz(xyz)
    # attribute name "mininum_kernel_size" [sic] comes from the project API
    scene_gs.mininum_kernel_size /= extent.item()
    scene_gs.from_scaling(scaling)
    return scene_gs
def make_scene(*outputs, in_place=False):
    """Merge per-object pipeline outputs into one gaussian scene.

    Each ``output`` dict is expected to carry a ``"gaussian"`` list plus
    ``"rotation"``, ``"translation"`` and ``"scale"`` local-to-camera
    transforms. Every object's gaussians are moved into the shared scene
    frame, then all gaussians are concatenated into the first object's
    container, which is returned.

    Args:
        *outputs: per-object pipeline result dicts.
        in_place: if False, deep-copy every output before mutating it.
    """
    if not in_place:
        outputs = [deepcopy(output) for output in outputs]
    all_outs = []
    minimum_kernel_size = float("inf")
    for output in outputs:
        # move gaussians to scene frame of reference
        PC = SceneVisualizer.object_pointcloud(
            points_local=output["gaussian"][0].get_xyz.unsqueeze(0),
            quat_l2c=output["rotation"],
            trans_l2c=output["translation"],
            scale_l2c=output["scale"],
        )
        output["gaussian"][0].from_xyz(PC.points_list()[0])
        # must ... ROTATE
        # compose the inverse object rotation into each gaussian's orientation
        output["gaussian"][0].from_rotation(
            quaternion_multiply(
                quaternion_invert(output["rotation"]),
                output["gaussian"][0].get_rotation,
            )
        )
        scale = output["gaussian"][0].get_scaling
        adjusted_scale = scale * output["scale"]
        # the per-object scale must be isotropic for the kernel-size update below
        assert (
            output["scale"][0, 0].item()
            == output["scale"][0, 1].item()
            == output["scale"][0, 2].item()
        )
        # attribute name "mininum_kernel_size" [sic] comes from the project API
        output["gaussian"][0].mininum_kernel_size *= output["scale"][0, 0].item()
        # keep gaussian scales strictly above the rescaled minimum kernel size
        adjusted_scale = torch.maximum(
            adjusted_scale,
            torch.tensor(
                output["gaussian"][0].mininum_kernel_size * 1.1,
                device=adjusted_scale.device,
            ),
        )
        output["gaussian"][0].from_scaling(adjusted_scale)
        # track the smallest kernel size across objects for the merged scene
        minimum_kernel_size = min(
            minimum_kernel_size,
            output["gaussian"][0].mininum_kernel_size,
        )
        all_outs.append(output)
    # merge gaussians
    scene_gs = all_outs[0]["gaussian"][0]
    scene_gs.mininum_kernel_size = minimum_kernel_size
    for out in all_outs[1:]:
        out_gs = out["gaussian"][0]
        scene_gs._xyz = torch.cat([scene_gs._xyz, out_gs._xyz], dim=0)
        scene_gs._features_dc = torch.cat(
            [scene_gs._features_dc, out_gs._features_dc], dim=0
        )
        scene_gs._scaling = torch.cat([scene_gs._scaling, out_gs._scaling], dim=0)
        scene_gs._rotation = torch.cat([scene_gs._rotation, out_gs._rotation], dim=0)
        scene_gs._opacity = torch.cat([scene_gs._opacity, out_gs._opacity], dim=0)
    return scene_gs
def check_target(
    target: str,
    whitelist_filters: List[Callable],
    blacklist_filters: List[Callable],
):
    """Validate that a hydra ``_target_`` string is safe to instantiate.

    A target passes when at least one whitelist filter matches it and no
    blacklist filter does; anything else raises RuntimeError. Blacklist
    filters are only evaluated for whitelisted targets (short-circuit).
    """
    if any(accept(target) for accept in whitelist_filters) and not any(
        reject(target) for reject in blacklist_filters
    ):
        return
    raise RuntimeError(
        f"target '{target}' is not allowed to be hydra instantiated, if this is a mistake, please do modify the whitelist_filters / blacklist_filters"
    )
def check_hydra_safety(
    config: DictConfig,
    whitelist_filters: List[Callable],
    blacklist_filters: List[Callable],
):
    """Walk an OmegaConf tree and vet every ``_target_`` via ``check_target``.

    Traverses dict and list nodes iteratively (no recursion) and raises
    RuntimeError as soon as a disallowed target is encountered.
    """
    pending = [config]
    while pending:
        node = pending.pop()
        if isinstance(node, DictConfig):
            if "_target_" in node:
                check_target(node["_target_"], whitelist_filters, blacklist_filters)
            pending.extend(list(node.values()))
        elif isinstance(node, ListConfig):
            pending.extend(list(node))
def load_image(path):
    """Load an image file as a uint8 numpy array.

    Uses ``Image.open`` as a context manager so the underlying file handle
    is closed promptly (the original left it open until garbage collection);
    ``np.array`` forces the lazy pixel data to load before the file closes.
    """
    with Image.open(path) as image:
        return np.array(image).astype(np.uint8)
def load_mask(path):
    """Load an image file as a boolean mask.

    Non-zero pixels count as foreground; for multi-channel images only the
    last channel (e.g. alpha) is kept.
    """
    binary = load_image(path) > 0
    return binary[..., -1] if binary.ndim == 3 else binary
def load_single_mask(folder_path, index=0, extension=".png"):
    """Load the single mask ``{index}{extension}`` from ``folder_path``."""
    (mask,) = load_masks(folder_path, [index], extension)
    return mask
def load_masks(folder_path, indices_list=None, extension=".png"):
    """Load boolean masks named ``{i}{extension}`` from a folder.

    When ``indices_list`` is None or empty, consecutive indices 0, 1, 2, ...
    are loaded until the first missing file.
    """
    wanted = [] if indices_list is None else list(indices_list)
    if not len(wanted) > 0:  # get all all masks if not provided
        idx = 0
        while os.path.exists(os.path.join(folder_path, f"{idx}{extension}")):
            wanted.append(idx)
            idx += 1
    masks = []
    for idx in wanted:
        mask_path = os.path.join(folder_path, f"{idx}{extension}")
        assert os.path.exists(mask_path), f"Mask path {mask_path} does not exist"
        masks.append(load_mask(mask_path))
    return masks
def display_image(image, masks=None):
    """Show an image and, optionally, its instance masks with matplotlib.

    Without masks a single axis is used. With masks, a 2x2 grid shows the
    original (top-left), a black background (top-right), the colored masks
    (bottom-left), and the image restricted to the mask union (bottom-right).
    """

    def imshow(image, ax):
        ax.axis("off")
        ax.imshow(image)

    grid = (1, 1) if masks is None else (2, 2)
    fig, axes = plt.subplots(*grid)
    if masks is not None:
        mask_colors = sns.color_palette("husl", len(masks))
        black_image = np.zeros_like(image[..., :3], dtype=float)  # background
        mask_display = np.copy(black_image)
        # BUGFIX: accumulate the union as bool; the original inherited the
        # image dtype via zeros_like, which makes `|=` raise for float images.
        mask_union = np.zeros(image.shape[:2] + (1,), dtype=bool)
        for i, mask in enumerate(masks):
            mask_display[mask] = mask_colors[i]
            mask_union = mask_union | (mask[..., None] if mask.ndim == 2 else mask)
        imshow(black_image, axes[0, 1])
        imshow(mask_display, axes[1, 0])
        imshow(image * mask_union, axes[1, 1])
    image_axe = axes if masks is None else axes[0, 0]
    imshow(image, image_axe)
    fig.tight_layout(pad=0)
    fig.show()
def interactive_visualizer(ply_path):
    """Launch a small shared Gradio app showing a splat/.ply file in 3D."""
    with gr.Blocks() as blocks:
        gr.Markdown("# 3D Gaussian Splatting (black-screen loading might take a while)")
        # splat file viewer
        gr.Model3D(value=ply_path, label="3D Scene")
    blocks.launch(share=True)