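"""Interactive Gradio demo for StructDiffusion.

StructDiffusion combines a diffusion model and an object-centric transformer to
construct structures from partial-view point clouds and high-level language
goals. This app lets a user load up to five meshes, place them on a table, and
generate a rearranged scene from a language command.

Website: https://structdiffusion.github.io/
Code: https://github.com/StructDiffusion/StructDiffusion
"""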
import os
import sys
import torch
import trimesh
import numpy as np
import pytorch_lightning as pl
import gradio as gr
from omegaconf import OmegaConf

sys.path.append('./src')

from StructDiffusion.data.semantic_arrangement_language_demo import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
from StructDiffusion.diffusion.sampler import SamplerV2
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh, get_trimesh_scene_with_table
from StructDiffusion.language.sentence_encoder import SentenceBertEncoder
import StructDiffusion.utils.transformations as tra
def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
    # obj_xyzs: B, N, P, 3 or 6
    # struct_pose: B, 1, 4, 4
    # pc_poses_in_struct: B, N, 4, 4
    device = obj_xyzs.device

    B, N, _, _ = pc_poses_in_struct.shape
    _, _, P, _ = obj_xyzs.shape

    # Current pose of each object point cloud: identity rotation with the
    # translation set to the centroid of its points.
    current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device)  # B, N, 4, 4
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2)  # B, N, 3
    current_pc_poses = current_pc_poses.reshape(B * N, 4, 4)  # B x N, 4, 4

    # Goal pose of each object point cloud: the structure frame composed with
    # the object's pose inside the structure.
    struct_pose = struct_pose.repeat(1, N, 1, 1)  # B, N, 4, 4
    struct_pose = struct_pose.reshape(B * N, 4, 4)  # B x N, 4, 4
    pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4)  # B x N, 4, 4
    goal_pc_pose = struct_pose @ pc_poses_in_struct  # B x N, 4, 4

    # Transform that moves each point cloud from its current pose to its goal pose.
    goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses)  # B x N, 4, 4

    # Apply the transforms in homogeneous coordinates. This is a version that
    # does not rely on pytorch3d (which uses row-major matrices and would
    # require transposing each transformation matrix).
    new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3]  # B x N, P, 3
    new_obj_xyzs = torch.cat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1)  # B x N, P, 4
    new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3]  # B x N, P, 3

    # Write the transformed points back into the B, N, P, 3 layout (in place,
    # so extra channels such as RGB are preserved).
    obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)

    return obj_xyzs
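# Example shapes (illustrative): with 3 elite samples of 4 objects sampled at
# 1024 points each,
#   obj_xyzs:           (3, 4, 1024, 6)   xyz + rgb
#   struct_pose:        (3, 1, 4, 4)
#   pc_poses_in_struct: (3, 4, 4, 4)
# the function returns obj_xyzs with the xyz channels moved into the goal
# arrangement while the rgb channels are left untouched.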
class Infer_Wrapper:

    def __init__(self, args, cfg):
        self.args = args
        self.num_pts = cfg.DATASET.num_pts

        # load
        pl.seed_everything(args.eval_random_seed)
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints"))
        collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints"))

        self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
        # override ignore_rgb for visualization
        cfg.DATASET.ignore_rgb = False
        self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)

        self.sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
                                 PairwiseCollisionModel, collision_checkpoint_path, self.device)
        self.sentence_encoder = SentenceBertEncoder()

        # Point clouds of the scene each user session has built, keyed by session id.
        self.session_id_to_obj_xyzs = {}

        # Make sure the directory for exported .glb scenes exists.
        os.makedirs("./tmp_data", exist_ok=True)

    def visualize_scene(self, di, session_id):
        raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
        language_command = raw_datum["template_sentence"]

        obj_xyz = raw_datum["pcs"]
        scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)

        # Reorient the scene for the Gradio Model3D viewer.
        scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi / 2))
        scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
        scene.export(scene_filename)

        return language_command, scene_filename
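    # The method above is kept from an earlier version of the demo that browsed
    # dataset examples; the current UI instead builds scenes from user-provided
    # meshes with build_scene below.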
    def build_scene(self, mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1,
                    mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2,
                    mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3,
                    mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4,
                    mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5, session_id):

        object_list = [(mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1),
                       (mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2),
                       (mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3),
                       (mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4),
                       (mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5)]

        scene = get_trimesh_scene_with_table()

        obj_xyzs = []
        for mesh_filename, x, y, z, ai, aj, ak, scale in object_list:
            if mesh_filename is None:
                continue
            obj_mesh = trimesh.load(mesh_filename)
            obj_mesh.apply_scale(scale)
            # Rest the object on the table: lift it by its lowest z before posing it.
            z_min = obj_mesh.bounds[0, 2]
            tform = tra.euler_matrix(ai, aj, ak)
            tform[:3, 3] = [x, y, z - z_min]
            obj_mesh.apply_transform(tform)
            # Sample a point cloud from the mesh surface as the model input.
            obj_xyz = obj_mesh.sample(self.num_pts)
            obj = trimesh.PointCloud(obj_xyz)
            scene.add_geometry(obj)
            obj_xyzs.append(obj_xyz)

        self.session_id_to_obj_xyzs[session_id] = obj_xyzs
        # Reorient the scene for the Gradio Model3D viewer.
        scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi / 2))
        scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
        scene.export(scene_filename)

        return scene_filename
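    # infer runs the full pipeline for a session's scene: encode the language
    # command with Sentence-BERT, batch the point clouds, sample goal poses with
    # the conditional pose diffusion model, rerank the samples with the pairwise
    # collision discriminator, and render the best arrangement to a .glb file.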
    def infer(self, language_command, session_id, progress=gr.Progress()):

        if session_id not in self.session_id_to_obj_xyzs:
            raise gr.Error("Please build the initial scene first.")
        obj_xyzs = self.session_id_to_obj_xyzs[session_id]

        sentence_embedding = self.sentence_encoder.encode([language_command]).flatten()

        raw_datum = self.dataset.build_data_from_xyzs(obj_xyzs, sentence_embedding)
        datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer, use_sentence_embedding=True)
        batch = self.dataset.single_datum_to_batch(datum, self.args.num_samples, self.device, inference_mode=True)
        num_poses = raw_datum["num_goal_poses"]

        struct_pose, pc_poses_in_struct = self.sampler.sample(batch, num_poses, self.args.num_elites, self.args.discriminator_batch_size)
        new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"][:self.args.num_elites], struct_pose, pc_poses_in_struct)

        # vis: render the first of the elite samples
        vis_obj_xyzs = new_obj_xyzs[:3]
        if torch.is_tensor(vis_obj_xyzs):
            if vis_obj_xyzs.is_cuda:
                vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
            vis_obj_xyzs = vis_obj_xyzs.numpy()

        vis_obj_xyz = vis_obj_xyzs[0]
        scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], obj_rgbs=None, return_scene=True)

        # Reorient the scene for the Gradio Model3D viewer.
        scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi / 2))
        scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
        scene.export(scene_filename)
        return scene_filename
# Inference settings (in place of command-line arguments).
args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
args.config_file = "./configs/conditional_pose_diffusion_language.yaml"
args.diffusion_checkpoint_id = "ConditionalPoseDiffusionLanguage"
args.collision_checkpoint_id = "CollisionDiscriminator"
args.eval_random_seed = 42
args.num_samples = 50
args.num_elites = 3
args.discriminator_batch_size = 10

base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
cfg = OmegaConf.merge(base_cfg, cfg)

infer_wrapper = Infer_Wrapper(args, cfg)
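# Scripted use of Infer_Wrapper without the UI (a sketch; assumes a local mesh
# file "mug.glb" exists and the checkpoints above have been downloaded):
#
#   empty = (None, 0, 0, 0, 0, 0, 0, 0)  # an unused object slot
#   infer_wrapper.build_scene("mug.glb", 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5,
#                             *empty, *empty, *empty, *empty, 0)
#   output_glb = infer_wrapper.infer("Put the objects in a line", 0)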
demo = gr.Blocks()
with demo:
    gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
    gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")

    session_id = gr.State(value=np.random.randint(0, 1000))

    def build_object_tab(tab_name):
        # One tab per object: a mesh loader plus position (x, y, z),
        # orientation (roll, pitch, yaw), and scale controls.
        with gr.Tab(tab_name):
            with gr.Column(scale=1, min_width=600):
                mesh_filename = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
                with gr.Row():
                    x = gr.Slider(0, 1, label="x")
                    y = gr.Slider(-0.5, 0.5, label="y")
                    z = gr.Slider(0, 0.5, label="z")
                with gr.Row():
                    ai = gr.Slider(0, np.pi * 2, label="roll")
                    aj = gr.Slider(0, np.pi * 2, label="pitch")
                    ak = gr.Slider(0, np.pi * 2, label="yaw")
                scale = gr.Slider(0, 1, label="scale")
        return [mesh_filename, x, y, z, ai, aj, ak, scale]

    # Five identical object tabs; the flattened component list matches the
    # positional arguments of Infer_Wrapper.build_scene.
    object_components = []
    for i in range(1, 6):
        object_components.extend(build_object_tab("Object {}".format(i)))

    b1 = gr.Button("Build Initial Scene")
    initial_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Initial 3D Scene")
    language_command = gr.Textbox(label="Input Language Command")
    b2 = gr.Button("Generate 3D Structure")
    output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
    b1.click(infer_wrapper.build_scene, inputs=object_components + [session_id], outputs=[initial_scene])
    b2.click(infer_wrapper.infer, inputs=[language_command, session_id], outputs=output_scene)
demo.queue(concurrency_count=10)
demo.launch()